processing

class datacheese.processing.KFoldCrossValidation(data, k, randomize=True, seed=None)

Bases: object

K-fold cross-validation iterator. Each iteration yields a (train_data, test_data) pair, with a different fold held out as the test data.

Parameters:
  • data (numpy.ndarray) – Array to split into k folds. Splitting is always done along axis 0.

  • k (int) – Number of folds.

  • randomize (bool, default True) – Whether to shuffle the data before splitting.

  • seed (int or None, default None) – Random seed used to shuffle the data.

Examples

>>> import numpy as np
>>> from datacheese.processing import KFoldCrossValidation

Generate data:

>>> X = np.arange(12).reshape(6, 2)
>>> X
array([[ 0,  1],
       [ 2,  3],
       [ 4,  5],
       [ 6,  7],
       [ 8,  9],
       [10, 11]])

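Split into 3 folds and iterate over them. Because randomize defaults to True and no seed is given, the rows are shuffled and the exact output varies between runs: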

>>> for i, (train_data, test_data) in enumerate(
...     KFoldCrossValidation(X, k=3)
... ):
...     print(f'Fold {i}')
...     print('Train Data:')
...     print(train_data)
...     print('Test Data:')
...     print(test_data)
Fold 0
Train Data:
[[ 4  5]
 [ 2  3]
 [ 8  9]
 [10 11]]
Test Data:
[[6 7]
 [0 1]]
Fold 1
Train Data:
[[ 6  7]
 [ 0  1]
 [ 8  9]
 [10 11]]
Test Data:
[[4 5]
 [2 3]]
Fold 2
Train Data:
[[6 7]
 [0 1]
 [4 5]
 [2 3]]
Test Data:
[[ 8  9]
 [10 11]]
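
For readers who want to see the mechanics behind the splits above, here is a minimal sketch in plain NumPy. It is not the datacheese implementation: the function name kfold_splits is hypothetical, and the use of numpy.random.default_rng and numpy.array_split is an assumption made for illustration. It mirrors the documented parameters (data, k, randomize, seed) and, as in the output above, places the training folds in fold order.

```python
import numpy as np


def kfold_splits(data, k, randomize=True, seed=None):
    """Yield (train, test) pairs. Hypothetical stand-in, not datacheese code.

    Assumes 2 <= k <= data.shape[0].
    """
    indices = np.arange(data.shape[0])
    if randomize:
        # Shuffle the row indices; a fixed seed makes the order reproducible.
        rng = np.random.default_rng(seed)
        rng.shuffle(indices)
    # Split the (possibly shuffled) indices into k nearly equal folds.
    folds = np.array_split(indices, k)
    for i in range(k):
        test_idx = folds[i]
        # Training rows are every fold except the i-th, kept in fold order.
        train_idx = np.concatenate(folds[:i] + folds[i + 1:])
        # Indexing along axis 0 matches the documented splitting axis.
        yield data[train_idx], data[test_idx]


X = np.arange(12).reshape(6, 2)
for i, (train_data, test_data) in enumerate(kfold_splits(X, k=3, seed=0)):
    print(f'Fold {i}')
    print('Train Data:')
    print(train_data)
    print('Test Data:')
    print(test_data)
```

In this sketch, passing an integer seed fixes the shuffle, so repeated runs produce the same folds. Whether a given seed produces the same ordering in KFoldCrossValidation depends on how the library shuffles internally, which this page does not specify.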