@@ -47,11 +47,26 @@ def get_data_split(
4747 local_param_dict ["resample" ] = None
4848 logger .warning ("Input data is not 2D, overriding resample strategy to None." )
4949
50- # No resampling
50+ # --- SAFEGUARD for Stratification ---
51+ # Check if any class has fewer than 2 samples for stratified splitting
52+ class_counts = y .value_counts ()
53+ min_class_count = class_counts .min ()
54+ use_stratify = min_class_count >= 2
55+
56+ if not use_stratify :
57+ logger .warning (
58+ f"Cannot use stratified split: smallest class has only { min_class_count } sample(s). "
59+ f"Class distribution: { class_counts .to_dict ()} . Using random split instead."
60+ )
61+ # Also disable resampling since we can't properly balance with so few samples
62+ if local_param_dict .get ("resample" ) is not None :
63+ logger .warning ("Disabling resampling due to insufficient samples in minority class." )
64+ local_param_dict ["resample" ] = None
65+
5166 # First, split into a preliminary training set and a final hold-out test set.
5267 # This original test set will NOT be resampled.
5368 X_train_orig , X_test_orig , y_train_orig , y_test_orig = train_test_split (
54- X , y , test_size = 0.25 , random_state = 1 , stratify = y
69+ X , y , test_size = 0.25 , random_state = 1 , stratify = y if use_stratify else None
5570 )
5671
5772 # --- SAFEGUARD for Resampling ---
@@ -60,7 +75,8 @@ def get_data_split(
6075 minority_class_count = y_train_orig .value_counts ().min ()
6176 if minority_class_count < 2 and local_param_dict .get ("resample" ) is not None :
6277 logger .warning (
63- f"Minority class has only { minority_class_count } sample(s) in the training fold. Disabling resampling to prevent errors."
78+ f"Minority class has only { minority_class_count } sample(s) in the training fold. "
79+ f"Disabling resampling to prevent errors."
6480 )
6581 local_param_dict ["resample" ] = None
6682
@@ -85,9 +101,6 @@ def get_data_split(
85101 y_name = y_train_orig .name
86102
87103 # Oversample training set
88- # --- CRITICAL FIX: Use 'auto' for sampling_strategy ---
89- # 'auto' is equivalent to 'minority' and correctly handles cases where
90- # the data is already balanced, preventing a ValueError.
91104 ros = RandomOverSampler (sampling_strategy = "auto" , random_state = 1 )
92105 X_train_res , y_train_res = ros .fit_resample (X_train_orig , y_train_orig )
93106
@@ -99,14 +112,25 @@ def get_data_split(
99112 X_train_processed = X_train_orig
100113 y_train_processed = y_train_orig
101114
115+ # Check again if we can stratify the second split
116+ train_class_counts = y_train_processed .value_counts ()
117+ min_train_class_count = train_class_counts .min ()
118+ use_stratify_second = min_train_class_count >= 2
119+
120+ if not use_stratify_second :
121+ logger .warning (
122+ f"Cannot use stratified split for train/validation: smallest class has only "
123+ f"{ min_train_class_count } sample(s). Using random split instead."
124+ )
125+
102126 # Finally, split the (potentially resampled) training data into the final
103127 # training and validation sets for the grid search.
104128 X_train , X_test , y_train , y_test = train_test_split (
105129 X_train_processed ,
106130 y_train_processed ,
107131 test_size = 0.25 ,
108132 random_state = 1 ,
109- stratify = y_train_processed ,
133+ stratify = y_train_processed if use_stratify_second else None ,
110134 )
111135
112136 return X_train , X_test , y_train , y_test , X_test_orig , y_test_orig
@@ -137,4 +161,4 @@ def is_valid_shape(input_data: Union[np.ndarray, pd.DataFrame]) -> bool:
137161
138162 else :
139163 # Input data is neither a numpy array nor a pandas DataFrame
140- return False
164+ return False
0 commit comments