Skip to content

Commit 13a567b

Browse files
author
SamoraHunter
committed
statification fix
1 parent c6b014c commit 13a567b

1 file changed

Lines changed: 32 additions & 8 deletions

File tree

ml_grid/pipeline/data_train_test_split.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,26 @@ def get_data_split(
4747
local_param_dict["resample"] = None
4848
logger.warning("Input data is not 2D, overriding resample strategy to None.")
4949

50-
# No resampling
50+
# --- SAFEGUARD for Stratification ---
51+
# Check if any class has fewer than 2 samples for stratified splitting
52+
class_counts = y.value_counts()
53+
min_class_count = class_counts.min()
54+
use_stratify = min_class_count >= 2
55+
56+
if not use_stratify:
57+
logger.warning(
58+
f"Cannot use stratified split: smallest class has only {min_class_count} sample(s). "
59+
f"Class distribution: {class_counts.to_dict()}. Using random split instead."
60+
)
61+
# Also disable resampling since we can't properly balance with so few samples
62+
if local_param_dict.get("resample") is not None:
63+
logger.warning("Disabling resampling due to insufficient samples in minority class.")
64+
local_param_dict["resample"] = None
65+
5166
# First, split into a preliminary training set and a final hold-out test set.
5267
# This original test set will NOT be resampled.
5368
X_train_orig, X_test_orig, y_train_orig, y_test_orig = train_test_split(
54-
X, y, test_size=0.25, random_state=1, stratify=y
69+
X, y, test_size=0.25, random_state=1, stratify=y if use_stratify else None
5570
)
5671

5772
# --- SAFEGUARD for Resampling ---
@@ -60,7 +75,8 @@ def get_data_split(
6075
minority_class_count = y_train_orig.value_counts().min()
6176
if minority_class_count < 2 and local_param_dict.get("resample") is not None:
6277
logger.warning(
63-
f"Minority class has only {minority_class_count} sample(s) in the training fold. Disabling resampling to prevent errors."
78+
f"Minority class has only {minority_class_count} sample(s) in the training fold. "
79+
f"Disabling resampling to prevent errors."
6480
)
6581
local_param_dict["resample"] = None
6682

@@ -85,9 +101,6 @@ def get_data_split(
85101
y_name = y_train_orig.name
86102

87103
# Oversample training set
88-
# --- CRITICAL FIX: Use 'auto' for sampling_strategy ---
89-
# 'auto' is equivalent to 'minority' and correctly handles cases where
90-
# the data is already balanced, preventing a ValueError.
91104
ros = RandomOverSampler(sampling_strategy="auto", random_state=1)
92105
X_train_res, y_train_res = ros.fit_resample(X_train_orig, y_train_orig)
93106

@@ -99,14 +112,25 @@ def get_data_split(
99112
X_train_processed = X_train_orig
100113
y_train_processed = y_train_orig
101114

115+
# Check again if we can stratify the second split
116+
train_class_counts = y_train_processed.value_counts()
117+
min_train_class_count = train_class_counts.min()
118+
use_stratify_second = min_train_class_count >= 2
119+
120+
if not use_stratify_second:
121+
logger.warning(
122+
f"Cannot use stratified split for train/validation: smallest class has only "
123+
f"{min_train_class_count} sample(s). Using random split instead."
124+
)
125+
102126
# Finally, split the (potentially resampled) training data into the final
103127
# training and validation sets for the grid search.
104128
X_train, X_test, y_train, y_test = train_test_split(
105129
X_train_processed,
106130
y_train_processed,
107131
test_size=0.25,
108132
random_state=1,
109-
stratify=y_train_processed,
133+
stratify=y_train_processed if use_stratify_second else None,
110134
)
111135

112136
return X_train, X_test, y_train, y_test, X_test_orig, y_test_orig
@@ -137,4 +161,4 @@ def is_valid_shape(input_data: Union[np.ndarray, pd.DataFrame]) -> bool:
137161

138162
else:
139163
# Input data is neither a numpy array nor a pandas DataFrame
140-
return False
164+
return False

0 commit comments

Comments
 (0)