Skip to content

Commit b2766ba

Browse files
author
SamoraHunter
committed
Force SVC to scale data with StandardScaler to handle persistent convergence issues
1 parent b197dca commit b2766ba

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
from ml_grid.util.global_params import global_parameters
4343
from ml_grid.util.project_score_save import project_score_save_class
4444
from ml_grid.util.validate_parameters import validate_parameters_helper
45-
from sklearn.preprocessing import MinMaxScaler
45+
from sklearn.preprocessing import MinMaxScaler, StandardScaler
4646
from ml_grid.util.bayes_utils import is_skopt_space
4747
from skopt.space import Categorical
4848

@@ -176,8 +176,16 @@ def __init__(
176176
max_param_space_iter_value = self.ml_grid_object_iter.local_param_dict.get("max_param_space_iter_value")
177177

178178
if "svc" in method_name.lower():
179-
self.X_train = scale_data(self.X_train)
180-
self.X_test = scale_data(self.X_test)
179+
self.logger.info("Applying StandardScaler for SVC to prevent convergence issues.")
180+
scaler = StandardScaler()
181+
self.X_train = pd.DataFrame(
182+
scaler.fit_transform(self.X_train),
183+
columns=self.X_train.columns,
184+
index=self.X_train.index,
185+
)
186+
self.X_test = pd.DataFrame(
187+
scaler.transform(self.X_test), columns=self.X_test.columns, index=self.X_test.index
188+
)
181189

182190
# --- PERFORMANCE FIX for testing ---
183191
# Use a much faster CV strategy when in test_mode.
@@ -408,6 +416,11 @@ def __init__(
408416
current_algorithm = search.run_search(X_train_reset, y_train_reset)
409417

410418
except Exception as e:
419+
if "dual coefficients or intercepts are not finite" in str(e):
420+
self.logger.warning(f"SVC failed to fit due to data issues: {e}. Returning default score.")
421+
self.grid_search_cross_validate_score_result = 0.5
422+
return
423+
411424
# Log the error and re-raise it to stop the entire execution,
412425
# allowing the main loop in main.py to handle it based on error_raise.
413426
self.logger.error(

0 commit comments

Comments
 (0)