1+ import unittest
2+ import pandas as pd
3+ import numpy as np
4+ from ml_grid .pipeline .data_train_test_split import get_data_split , is_valid_shape
5+
6+ class TestDataTrainTestSplit (unittest .TestCase ):
7+
8+ def setUp (self ):
9+ """Set up a sample DataFrame and Series for testing."""
10+ # Create an imbalanced dataset
11+ data = {f'feature_{ i } ' : np .random .rand (100 ) for i in range (5 )}
12+ data ['target' ] = [0 ] * 80 + [1 ] * 20
13+ self .X = pd .DataFrame (data )
14+ self .y = self .X .pop ('target' )
15+
16+ def test_is_valid_shape (self ):
17+ """Test the is_valid_shape function."""
18+ self .assertTrue (is_valid_shape (self .X ))
19+ self .assertTrue (is_valid_shape (self .X .values ))
20+
21+ # Test with a 3D numpy array
22+ invalid_shape_np = np .random .rand (10 , 5 , 2 )
23+ self .assertFalse (is_valid_shape (invalid_shape_np ))
24+
25+ # Test with a list
26+ self .assertFalse (is_valid_shape ([1 , 2 , 3 ]))
27+
28+ def test_get_data_split_no_resample (self ):
29+ """Test data splitting without any resampling."""
30+ local_param_dict = {'resample' : None }
31+ X_train , X_test , y_train , y_test , X_test_orig , y_test_orig = get_data_split (
32+ self .X , self .y , local_param_dict
33+ )
34+
35+ # Check original test set size (25% of total)
36+ self .assertEqual (len (X_test_orig ), 25 )
37+ self .assertEqual (len (y_test_orig ), 25 )
38+
39+ # Check final train/test sizes (split from the initial 75%)
40+ # 75% of 75 = 56.25 -> 56, 25% of 75 = 18.75 -> 19
41+ self .assertEqual (len (X_train ), 56 )
42+ self .assertEqual (len (y_train ), 56 )
43+ self .assertEqual (len (X_test ), 19 )
44+ self .assertEqual (len (y_test ), 19 )
45+
46+ # Total samples should be conserved
47+ self .assertEqual (len (X_train ) + len (X_test ) + len (X_test_orig ), len (self .X ))
48+
49+ def test_get_data_split_undersample (self ):
50+ """Test data splitting with undersampling."""
51+ local_param_dict = {'resample' : 'undersample' }
52+ X_train , X_test , y_train , y_test , X_test_orig , y_test_orig = get_data_split (
53+ self .X , self .y , local_param_dict
54+ )
55+
56+ # The entire dataset is first undersampled to 20*2=40 samples
57+ # Then split 75/25 -> 30/10
58+ self .assertEqual (len (X_test_orig ), 10 )
59+
60+ # Then the 30 are split 75/25 -> 22/8
61+ self .assertEqual (len (X_train ), 22 )
62+ self .assertEqual (len (y_train ), 22 )
63+ self .assertEqual (len (X_test ), 8 )
64+
65+ # Check if the training set is balanced after the full process
66+ # The final y_train comes from a split of a balanced set, so it should be roughly balanced
67+ self .assertAlmostEqual (y_train .value_counts (normalize = True )[0 ], 0.5 , delta = 0.2 )
68+ self .assertAlmostEqual (y_train .value_counts (normalize = True )[1 ], 0.5 , delta = 0.2 )
69+
70+ def test_get_data_split_oversample (self ):
71+ """Test data splitting with oversampling."""
72+ local_param_dict = {'resample' : 'oversample' }
73+ X_train , X_test , y_train , y_test , X_test_orig , y_test_orig = get_data_split (
74+ self .X , self .y , local_param_dict
75+ )
76+
77+ # Original data is split 75/25 -> 75/25
78+ self .assertEqual (len (X_test_orig ), 25 )
79+
80+ # The initial training set of 75 (60 class 0, 15 class 1) is oversampled
81+ # to have 62 of each class, totaling 124.
82+ # This is then split 75/25 -> 93/31
83+ self .assertEqual (len (X_train ), 93 )
84+ self .assertEqual (len (y_train ), 93 )
85+ self .assertEqual (len (X_test ), 31 )
86+
87+ # The final y_train should be as balanced as possible.
88+ # With an odd number of samples (93), a perfect 50/50 split is impossible.
89+ self .assertAlmostEqual (y_train .value_counts ()[0 ], y_train .value_counts ()[1 ], delta = 1 )
90+
91+ def test_invalid_shape_overrides_resample (self ):
92+ """Test that resampling is disabled for invalid (e.g., 3D) data shapes."""
93+ X_3d = np .random .rand (100 , 5 , 2 )
94+ y_3d = pd .Series (self .y .values ) # y can remain 1D
95+
96+ local_param_dict = {'resample' : 'oversample' }
97+ # This should run without error and default to 'resample': None
98+ X_train , _ , _ , _ , _ , _ = get_data_split (
99+ X_3d , y_3d , local_param_dict
100+ )
101+ # Check that the dictionary was modified in-place
102+ self .assertIsNone (local_param_dict ['resample' ])
103+ self .assertEqual (X_train .shape [0 ], 56 ) # Should match 'no resample' case
104+
105+ if __name__ == '__main__' :
106+ unittest .main ()
0 commit comments