Skip to content

Commit c64fe9d

Browse files
author
SamoraHunter
committed
fall back for single class training set after split
1 parent 55099b3 commit c64fe9d

2 files changed

Lines changed: 30 additions & 0 deletions

File tree

ml_grid/pipeline/data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,12 @@ def _split_data(self):
474474
self.y_test_orig,
475475
) = get_data_split(X=self.X, y=self.y, local_param_dict=self.local_param_dict)
476476

477+
if self.y_train.nunique() < 2:
478+
self.logger.warning(
479+
f"Training data contains only {self.y_train.nunique()} class(es). "
480+
"Model fitting may fail if the algorithm requires at least 2 classes."
481+
)
482+
477483
self._assert_index_alignment(self.X_train, self.y_train, "After get_data_split")
478484

479485
# --- CRITICAL FIX: Reset all indices immediately after splitting ---

ml_grid/pipeline/data_train_test_split.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,30 @@ def get_data_split(
135135
stratify=y_train_processed if use_stratify_second else None,
136136
)
137137

138+
# --- Fallback for single-class training set ---
139+
# If the random split resulted in a training set with only 1 class (but we had 2+ available),
140+
# we attempt to move a sample from the test set to the training set to prevent model failure.
141+
if y_train.nunique() < 2 and y_train_processed.nunique() >= 2:
142+
logger.warning(
143+
"y_train contains only 1 class after split. Attempting to move a sample from X_test to X_train to ensure class presence."
144+
)
145+
missing_classes = set(y_train_processed.unique()) - set(y_train.unique())
146+
for missing_cls in missing_classes:
147+
# Find candidates in test set
148+
candidates = y_test[y_test == missing_cls]
149+
if not candidates.empty:
150+
idx_to_move = candidates.index[0]
151+
152+
# Move from test to train
153+
X_train = pd.concat([X_train, X_test.loc[[idx_to_move]]])
154+
y_train = pd.concat([y_train, y_test.loc[[idx_to_move]]])
155+
156+
X_test = X_test.drop(idx_to_move)
157+
y_test = y_test.drop(idx_to_move)
158+
159+
logger.info(f"Moved sample {idx_to_move} (class {missing_cls}) from test to train.")
160+
break # Only need one sample to satisfy "at least 2 classes"
161+
138162
return X_train, X_test, y_train, y_test, X_test_orig, y_test_orig
139163

140164

0 commit comments

Comments
 (0)