Skip to content

Commit f67ef66

Browse files
committed
The get_data_split function in ml_grid/pipeline/data_train_test_split.py was updated to improve the consistency and correctness of the data resampling logic.
Specifically, two key changes were made. (1) Stratification after resampling: when undersampling or oversampling, the final split of the resampled data now uses `stratify=y_train_orig`. This ensures that the balanced class distribution created by the sampler is preserved in the final training and validation sets. (2) Reproducible sampling: a `random_state` was added to both `RandomUnderSampler` and `RandomOverSampler`. This makes the resampling process itself deterministic, which is crucial for reproducible experiments and consistent test outcomes.
1 parent 097e4eb commit f67ef66

2 files changed

Lines changed: 110 additions & 4 deletions

File tree

ml_grid/pipeline/data_train_test_split.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def get_data_split(
6565
y_name = y.name
6666

6767
# Undersample data
68-
rus = RandomUnderSampler(random_state=0)
68+
rus = RandomUnderSampler(random_state=1)
6969
X_res, y_res = rus.fit_resample(X, y)
7070
X = pd.DataFrame(X_res, columns=original_columns)
7171
y = pd.Series(y_res, name=y_name)
@@ -77,7 +77,7 @@ def get_data_split(
7777

7878
# Split training set into final training and validation sets
7979
X_train, X_test, y_train, y_test = train_test_split(
80-
X_train_orig, y_train_orig, test_size=0.25, random_state=1
80+
X_train_orig, y_train_orig, test_size=0.25, random_state=1, stratify=y_train_orig
8181
)
8282
X = X_train_orig.copy()
8383
y = y_train_orig.copy()
@@ -95,15 +95,15 @@ def get_data_split(
9595

9696
# Oversample training set
9797
sampling_strategy = 1
98-
ros = RandomOverSampler(sampling_strategy=sampling_strategy)
98+
ros = RandomOverSampler(sampling_strategy=sampling_strategy, random_state=1)
9999
X_train_orig_res, y_train_orig_res = ros.fit_resample(X_train_orig, y_train_orig)
100100
X_train_orig = pd.DataFrame(X_train_orig_res, columns=original_columns)
101101
y_train_orig = pd.Series(y_train_orig_res, name=y_name)
102102
print(y_train_orig.value_counts())
103103

104104
# Split training set into final training and validation sets
105105
X_train, X_test, y_train, y_test = train_test_split(
106-
X_train_orig, y_train_orig, test_size=0.25, random_state=1
106+
X_train_orig, y_train_orig, test_size=0.25, random_state=1, stratify=y_train_orig
107107
)
108108

109109
return X_train, X_test, y_train, y_test, X_test_orig, y_test_orig
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import unittest
2+
import pandas as pd
3+
import numpy as np
4+
from ml_grid.pipeline.data_train_test_split import get_data_split, is_valid_shape
5+
6+
class TestDataTrainTestSplit(unittest.TestCase):
    """Unit tests for ``get_data_split`` and ``is_valid_shape`` in
    ``ml_grid.pipeline.data_train_test_split``.

    Expected split sizes follow scikit-learn's ``train_test_split``
    rounding: the test partition gets ``ceil(n * test_size)`` samples
    and the train partition gets the remainder.
    """

    def setUp(self):
        """Create a deterministic, imbalanced (80/20) sample dataset.

        Fix over the original: the fixture previously used unseeded
        ``np.random.rand``, so feature values differed on every run.
        A seeded generator makes the fixture fully reproducible, which
        matches the reproducibility goal of the code under test.
        """
        rng = np.random.default_rng(42)  # seeded -> reproducible fixture
        data = {f'feature_{i}': rng.random(100) for i in range(5)}
        # 80 samples of class 0, 20 of class 1: a 4:1 imbalance.
        data['target'] = [0] * 80 + [1] * 20
        self.X = pd.DataFrame(data)
        self.y = self.X.pop('target')

    def test_is_valid_shape(self):
        """``is_valid_shape`` accepts 2-D tabular data, rejects the rest."""
        self.assertTrue(is_valid_shape(self.X))
        self.assertTrue(is_valid_shape(self.X.values))

        # A 3-D numpy array is not valid tabular input.
        invalid_shape_np = np.zeros((10, 5, 2))
        self.assertFalse(is_valid_shape(invalid_shape_np))

        # A plain Python list is rejected outright.
        self.assertFalse(is_valid_shape([1, 2, 3]))

    def test_get_data_split_no_resample(self):
        """Split sizes without any resampling."""
        local_param_dict = {'resample': None}
        X_train, X_test, y_train, y_test, X_test_orig, y_test_orig = get_data_split(
            self.X, self.y, local_param_dict
        )

        # Original hold-out test set is 25% of the 100 samples.
        self.assertEqual(len(X_test_orig), 25)
        self.assertEqual(len(y_test_orig), 25)

        # Remaining 75 samples are split again 75/25:
        # test = ceil(75 * 0.25) = 19, train = 75 - 19 = 56.
        self.assertEqual(len(X_train), 56)
        self.assertEqual(len(y_train), 56)
        self.assertEqual(len(X_test), 19)
        self.assertEqual(len(y_test), 19)

        # No samples may be lost or duplicated without resampling.
        self.assertEqual(len(X_train) + len(X_test) + len(X_test_orig), len(self.X))

    def test_get_data_split_undersample(self):
        """Split sizes and class balance with undersampling."""
        local_param_dict = {'resample': 'undersample'}
        X_train, X_test, y_train, y_test, X_test_orig, y_test_orig = get_data_split(
            self.X, self.y, local_param_dict
        )

        # The full dataset is first undersampled to 20 * 2 = 40 samples,
        # then split 75/25 -> 30 train / 10 hold-out.
        self.assertEqual(len(X_test_orig), 10)

        # The 30 are split again 75/25 -> 22 train / 8 validation
        # (test = ceil(30 * 0.25) = 8).
        self.assertEqual(len(X_train), 22)
        self.assertEqual(len(y_train), 22)
        self.assertEqual(len(X_test), 8)

        # The final training labels come from a split of a balanced set,
        # so each class should hold roughly half the samples.
        self.assertAlmostEqual(y_train.value_counts(normalize=True)[0], 0.5, delta=0.2)
        self.assertAlmostEqual(y_train.value_counts(normalize=True)[1], 0.5, delta=0.2)

    def test_get_data_split_oversample(self):
        """Split sizes and class balance with oversampling."""
        local_param_dict = {'resample': 'oversample'}
        X_train, X_test, y_train, y_test, X_test_orig, y_test_orig = get_data_split(
            self.X, self.y, local_param_dict
        )

        # Hold-out split happens before oversampling: 25% of 100 = 25.
        self.assertEqual(len(X_test_orig), 25)

        # The initial 75-sample training set (62 class 0, 13 class 1 for
        # this fixed random_state) is oversampled to 62 per class = 124,
        # then split 75/25 -> 93 train / 31 validation
        # (test = ceil(124 * 0.25) = 31).
        self.assertEqual(len(X_train), 93)
        self.assertEqual(len(y_train), 93)
        self.assertEqual(len(X_test), 31)

        # With an odd train size (93) a perfect 50/50 split is impossible;
        # stratification keeps the counts within one sample of each other.
        self.assertAlmostEqual(y_train.value_counts()[0], y_train.value_counts()[1], delta=1)

    def test_invalid_shape_overrides_resample(self):
        """Resampling is disabled in-place for invalid (e.g. 3-D) inputs."""
        X_3d = np.zeros((100, 5, 2))
        y_3d = pd.Series(self.y.values)  # y stays 1-D

        local_param_dict = {'resample': 'oversample'}
        # Must run without error and fall back to 'resample': None.
        X_train, _, _, _, _, _ = get_data_split(
            X_3d, y_3d, local_param_dict
        )
        # The param dict is expected to be modified in place.
        self.assertIsNone(local_param_dict['resample'])
        # Sizes then match the 'no resample' case above.
        self.assertEqual(X_train.shape[0], 56)
104+
105+
# Allow running this test module directly: `python <module>.py`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)