processing
- class datacheese.processing.KFoldCrossValidation(data, k, randomize=True, seed=None)
Bases: object
K-fold cross validation iterator class.
- Parameters:
data (numpy.ndarray) – Array to split into k folds. Splitting is always done along axis 0.
k (int) – Number of folds.
randomize (bool, default True) – Whether to shuffle the data before splitting.
seed (int or None, default None) – Random seed used to shuffle the data.
Examples
>>> import numpy as np
>>> from datacheese.processing import KFoldCrossValidation
Generate data:
>>> X = np.arange(12).reshape(6, 2)
>>> X
array([[ 0,  1],
       [ 2,  3],
       [ 4,  5],
       [ 6,  7],
       [ 8,  9],
       [10, 11]])
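Split into 3 folds and iterate over them. Because randomize defaults to True and no seed is given, the row order shown here will generally differ between runs: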
>>> for i, (train_data, test_data) in enumerate(
...     KFoldCrossValidation(X, k=3)
... ):
...     print(f'Fold {i}')
...     print('Train Data:')
...     print(train_data)
...     print('Test Data:')
...     print(test_data)
Fold 0
Train Data:
[[ 4  5]
 [ 2  3]
 [ 8  9]
 [10 11]]
Test Data:
[[6 7]
 [0 1]]
Fold 1
Train Data:
[[ 6  7]
 [ 0  1]
 [ 8  9]
 [10 11]]
Test Data:
[[4 5]
 [2 3]]
Fold 2
Train Data:
[[6 7]
 [0 1]
 [4 5]
 [2 3]]
Test Data:
[[ 8  9]
 [10 11]]
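For readers who want to see the general technique, the following is a minimal NumPy-only sketch of k-fold splitting along axis 0. It is not the datacheese implementation: the helper name kfold_splits is hypothetical, and details such as the use of numpy.random.default_rng and numpy.array_split are assumptions chosen for illustration, so fold contents may not match those produced by KFoldCrossValidation for the same seed.

```python
import numpy as np


def kfold_splits(data, k, randomize=True, seed=None):
    """Yield (train, test) pairs for k folds along axis 0.

    Hypothetical helper for illustration only; not part of datacheese.
    """
    if k < 2:
        raise ValueError("k must be at least 2")
    indices = np.arange(data.shape[0])
    if randomize:
        # Shuffle row indices; a fixed seed makes the split repeatable.
        rng = np.random.default_rng(seed)
        rng.shuffle(indices)
    # array_split also handles row counts that are not divisible by k.
    folds = np.array_split(indices, k)
    for i in range(k):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield data[train_idx], data[test_idx]


X = np.arange(12).reshape(6, 2)
splits = list(kfold_splits(X, k=3, randomize=False))
for i, (train_data, test_data) in enumerate(splits):
    print(f'Fold {i}: {len(train_data)} train rows, {len(test_data)} test rows')
```

Each fold serves as the test set exactly once, and the remaining k - 1 folds are concatenated to form the training set. In this sketch, passing the same integer seed twice yields identical splits.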