Skip to content

Commit ae2a661

Browse files
committed
Refactor data pipeline for robustness and fix parameter validation
ml_grid/pipeline/data.py: Add checks for pd.DataFrame before applying DataFrame-specific operations (cleaning, scaling, feature selection, embeddings) to support non-DataFrame inputs (e.g., Time Series). Improve _assert_index_alignment to handle non-pandas objects by checking length. Safely call reset_index only if the method exists on the object. ml_grid/pipeline/data_train_test_split.py: Handle numpy array inputs for X and y, converting to pandas objects where appropriate for splitting logic. Add check for DataFrame type before attempting to move samples in single-class fallback logic. ml_grid/pipeline/grid_search_cross_validate.py: Add automatic configuration of XLA_FLAGS for CUDA to resolve libdevice errors. Ensure FLAMLClassifierWrapper and AutoKerasClassifierWrapper receive DataFrames and run single-threaded during final CV. Remove logic that skipped final CV in test mode. ml_grid/util/validate_parameters.py: Update validation functions (validate_knn_parameters, validate_XGB_parameters) to handle lists of parameter dictionaries (Grid Search format) via recursion. Update type hints to support lists of dictionaries. ml_grid/pipeline/test_data_pipeline.py: Enable test_mode in test setup.
1 parent 2b1ffb6 commit ae2a661

5 files changed

Lines changed: 141 additions & 50 deletions

File tree

ml_grid/pipeline/data.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@
2323
from ml_grid.pipeline.embeddings import create_embedding_pipeline
2424
from ml_grid.util.global_params import global_parameters
2525
from ml_grid.util.logger_setup import setup_logger
26-
27-
warnings.filterwarnings("ignore", category=ConvergenceWarning)
28-
warnings.filterwarnings("ignore", category=UserWarning)
2926
from sklearn.preprocessing import (
3027
StandardScaler,
3128
) # Added explicit import for StandardScaler
3229

30+
warnings.filterwarnings("ignore", category=ConvergenceWarning)
31+
warnings.filterwarnings("ignore", category=UserWarning)
3332
warnings.filterwarnings("ignore", category=FutureWarning)
3433

3534

@@ -135,10 +134,20 @@ def _log_feature_transformation(
135134
}
136135
)
137136

138-
def _assert_index_alignment(
139-
self, df1: pd.DataFrame, df2: pd.Series, step_name: str
140-
):
137+
def _assert_index_alignment(self, df1: Any, df2: Any, step_name: str):
141138
"""Helper function to assert that DataFrame and Series indices are equal."""
139+
# Handle objects without .index (e.g. numpy arrays in time series mode)
140+
if not hasattr(df1, "index") or not hasattr(df2, "index"):
141+
if len(df1) != len(df2):
142+
self.logger.error(
143+
f"Length mismatch at {step_name}: {len(df1)} vs {len(df2)}"
144+
)
145+
raise AssertionError(f"Length mismatch at {step_name}")
146+
self.logger.debug(
147+
f"Length alignment PASSED at: {step_name} (non-pandas objects)"
148+
)
149+
return
150+
142151
try:
143152
assert_index_equal(df1.index, df2.index)
144153
self.logger.debug(f"Index alignment PASSED at: {step_name}")
@@ -499,18 +508,28 @@ def _split_data(self):
499508
# --- CRITICAL FIX: Reset all indices immediately after splitting ---
500509
# This ensures all downstream processing (constant removal, feature selection, embedding)
501510
# operates on data with clean, aligned, 0-based integer indices.
502-
self.X_train.reset_index(drop=True, inplace=True)
503-
self.y_train.reset_index(drop=True, inplace=True)
504-
self.X_test.reset_index(drop=True, inplace=True)
505-
self.y_test.reset_index(drop=True, inplace=True)
506-
self.X_test_orig.reset_index(drop=True, inplace=True)
507-
self.y_test_orig.reset_index(drop=True, inplace=True)
511+
if hasattr(self.X_train, "reset_index"):
512+
self.X_train.reset_index(drop=True, inplace=True)
513+
self.X_test.reset_index(drop=True, inplace=True)
514+
self.X_test_orig.reset_index(drop=True, inplace=True)
515+
516+
if hasattr(self.y_train, "reset_index"):
517+
self.y_train.reset_index(drop=True, inplace=True)
518+
self.y_test.reset_index(drop=True, inplace=True)
519+
self.y_test_orig.reset_index(drop=True, inplace=True)
520+
508521
self._assert_index_alignment(
509522
self.X_train, self.y_train, "After master reset_index"
510523
)
511524

512525
def _post_split_cleaning(self):
513526
"""Applies cleaning steps post-split to prevent data leakage."""
527+
if not isinstance(self.X_train, pd.DataFrame):
528+
self.logger.info(
529+
"Skipping post-split cleaning (not a DataFrame, likely Time Series mode)."
530+
)
531+
return
532+
514533
# Clean column names *before* dropping operations to ensure stable column order.
515534
cleanup = clean_up_class()
516535
cleanup.screen_non_float_types(self.X_train)
@@ -608,6 +627,12 @@ def _post_split_cleaning(self):
608627

609628
def _scale_features(self):
610629
"""Applies standard scaling to the feature sets."""
630+
if not isinstance(self.X_train, pd.DataFrame):
631+
self.logger.info(
632+
"Skipping scaling (not a DataFrame, likely Time Series mode)."
633+
)
634+
return
635+
611636
features_before = self.X_train.shape[1]
612637
scale = self.local_param_dict.get("scale")
613638
if scale:
@@ -650,6 +675,12 @@ def _scale_features(self):
650675

651676
def _select_features_by_importance(self):
652677
"""Selects features based on importance scores if configured."""
678+
if not isinstance(self.X_train, pd.DataFrame):
679+
self.logger.info(
680+
"Skipping feature selection (not a DataFrame, likely Time Series mode)."
681+
)
682+
return
683+
653684
target_n_features = self.local_param_dict.get("feature_n")
654685

655686
if target_n_features is not None and target_n_features < 100:
@@ -752,6 +783,12 @@ def _select_features_by_importance(self):
752783

753784
def _apply_embeddings(self):
754785
"""Applies feature embedding/dimensionality reduction if configured."""
786+
if not isinstance(self.X_train, pd.DataFrame):
787+
self.logger.info(
788+
"Skipping embeddings (not a DataFrame, likely Time Series mode)."
789+
)
790+
return
791+
755792
if self.local_param_dict.get("use_embedding", False):
756793
features_before = self.X_train.shape[1]
757794
self.logger.info("Applying embeddings...")
@@ -925,16 +962,9 @@ def _finalize_pipeline(self):
925962
# Final definitive assertion before exiting the data pipeline.
926963
# This ensures that the X_train and y_train that will be passed to the
927964
# model training steps are perfectly aligned.
928-
try:
929-
assert_index_equal(self.X_train.index, self.y_train.index)
930-
self.logger.info(
931-
"Final data alignment check PASSED. X_train and y_train indices are identical."
932-
)
933-
except AssertionError:
934-
self.logger.error(
935-
"CRITICAL: Final data alignment check FAILED. X_train and y_train indices are NOT identical."
936-
)
937-
raise
965+
self._assert_index_alignment(
966+
self.X_train, self.y_train, "Final data alignment check"
967+
)
938968

939969
def _compile_and_log_feature_transformations(self, error_occurred: bool = False):
940970
"""Compiles the feature transformation log and displays it."""

ml_grid/pipeline/data_train_test_split.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ def get_data_split(
4242
random.seed(1234)
4343
np.random.seed(1234)
4444

45+
# --- Handle Numpy Inputs (e.g. from Time Series mode) ---
46+
if isinstance(y, np.ndarray):
47+
y = pd.Series(y)
48+
49+
# Ensure X is a pandas DataFrame if it's 2D, to support column access if resampling is used.
50+
# If X is >2D (e.g. time series), it stays as numpy array.
51+
if isinstance(X, np.ndarray) and X.ndim == 2:
52+
X = pd.DataFrame(X)
53+
4554
# Check if data is valid
4655
if not is_valid_shape(X):
4756
local_param_dict["resample"] = None
@@ -138,7 +147,11 @@ def get_data_split(
138147
# --- Fallback for single-class training set ---
139148
# If the random split resulted in a training set with only 1 class (but we had 2+ available),
140149
# we attempt to move a sample from the test set to the training set to prevent model failure.
141-
if y_train.nunique() < 2 and y_train_processed.nunique() >= 2:
150+
if (
151+
y_train.nunique() < 2
152+
and y_train_processed.nunique() >= 2
153+
and isinstance(X_train, pd.DataFrame)
154+
):
142155
logger.warning(
143156
"y_train contains only 1 class after split. Attempting to move a sample from X_test to X_train to ensure class presence."
144157
)

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import multiprocessing
44
import joblib
55
import warnings
6+
import os
7+
import sys
68
from typing import Any, Dict, List, Optional, Union
79

810
import numpy as np
@@ -15,7 +17,9 @@
1517
from sklearn import metrics
1618
from pandas.testing import assert_index_equal
1719
from xgboost.core import XGBoostError
20+
from ml_grid.model_classes.AutoKerasClassifierWrapper import AutoKerasClassifierWrapper
1821
from ml_grid.model_classes.H2OAutoMLClassifier import H2OAutoMLClassifier
22+
from ml_grid.model_classes.FLAMLClassifierWrapper import FLAMLClassifierWrapper
1923
from ml_grid.model_classes.H2OGBMClassifier import H2OGBMClassifier
2024
from ml_grid.model_classes.H2ODRFClassifier import H2ODRFClassifier
2125
from ml_grid.model_classes.H2OGAMClassifier import H2OGAMClassifier
@@ -158,6 +162,34 @@ def __init__(
158162
# One-time TF/GPU Setup
159163
if is_gpu_model and not _TF_INITIALIZED:
160164
try:
165+
# --- FIX for libdevice error ---
166+
# Set XLA_FLAGS to point to the CUDA toolkit installed by pip.
167+
# This is crucial for XLA to find the libdevice library for GPU compilation.
168+
if "XLA_FLAGS" not in os.environ:
169+
# Find site-packages directory
170+
site_packages_path = next(
171+
(p for p in sys.path if "site-packages" in p), None
172+
)
173+
if site_packages_path:
174+
# The 'nvidia-cuda-nvcc-cu12' package installs the compiler toolkit here.
175+
# XLA needs this path to find the 'nvvm/libdevice' directory.
176+
cuda_path = os.path.join(
177+
site_packages_path, "nvidia", "cuda_nvcc"
178+
)
179+
180+
if os.path.exists(cuda_path):
181+
self.logger.info(
182+
f"Found CUDA compiler toolkit at {cuda_path}. Setting XLA_FLAGS."
183+
)
184+
os.environ["XLA_FLAGS"] = (
185+
f"--xla_gpu_cuda_data_dir={cuda_path}"
186+
)
187+
else:
188+
self.logger.warning(
189+
"Could not find 'nvidia/cuda_nvcc' directory. Falling back to site-packages root. "
190+
"Install 'nvidia-cuda-nvcc-cu12' for a reliable setup."
191+
)
192+
161193
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
162194
if gpu_devices:
163195
for device in gpu_devices:
@@ -523,7 +555,10 @@ def __init__(
523555
# Convert y to numpy for ALL models
524556
y_train_search = self._optimize_y(y_train_reset)
525557

526-
if not is_h2o_model:
558+
# Pass DataFrame to H2O and FLAML, which need column info.
559+
# Other models get a numpy array for performance.
560+
is_flaml_model = isinstance(current_algorithm, FLAMLClassifierWrapper)
561+
if not is_h2o_model and not is_flaml_model:
527562
X_train_search = X_train_reset.values
528563
else:
529564
X_train_search = X_train_reset
@@ -579,16 +614,6 @@ def __init__(
579614
# Restore the original grid_n_jobs setting
580615
self.global_parameters.grid_n_jobs = original_grid_n_jobs
581616

582-
# Skip final CV in test mode
583-
if not failed and getattr(self.global_parameters, "test_mode", False):
584-
self.logger.info(
585-
"Test mode enabled. Skipping final cross-validation for speed."
586-
)
587-
self.grid_search_cross_validate_score_result = 0.5 # Return a valid float
588-
# Final cleanup for H2O models
589-
self._shutdown_h2o_if_needed(current_algorithm)
590-
return
591-
592617
if not failed and self.global_parameters.verbose >= 3:
593618
self.logger.debug("Fitting final model")
594619

@@ -612,19 +637,26 @@ def __init__(
612637

613638
is_h2o_model = isinstance(current_algorithm, H2O_MODEL_TYPES)
614639
is_keras_model = isinstance(current_algorithm, keras_model_types)
640+
is_flaml_model = isinstance(current_algorithm, FLAMLClassifierWrapper)
641+
is_autokeras_model = isinstance(current_algorithm, AutoKerasClassifierWrapper)
615642

616643
# H2O and Keras models require single-threaded execution for CV
617-
final_cv_n_jobs = 1 if is_h2o_model or is_keras_model else grid_n_jobs
644+
final_cv_n_jobs = (
645+
1
646+
if is_h2o_model or is_keras_model or is_flaml_model or is_autokeras_model
647+
else grid_n_jobs
648+
)
618649
if final_cv_n_jobs == 1:
619650
self.logger.debug(
620-
"H2O or Keras model detected. Forcing n_jobs=1 for final cross-validation."
651+
"H2O, Keras, FLAML, or AutoKeras model detected. Forcing n_jobs=1 for final cross-validation."
621652
)
622653

623654
try:
624655
if failed:
625656
raise TimeoutError
626657

627-
if isinstance(current_algorithm, H2O_MODEL_TYPES):
658+
# H2O, FLAML and AutoKeras require pandas DataFrame to handle categorical features correctly.
659+
if is_h2o_model or is_flaml_model or is_autokeras_model:
628660
X_train_final = self.X_train # Pass DataFrame directly
629661
y_train_final = self._optimize_y(self.y_train)
630662
else:

ml_grid/pipeline/test_data_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def setUp(self):
4444
global_parameters.verbose = 0 # Keep test output clean
4545
global_parameters.error_raise = True
4646
global_parameters.bayessearch = False # Explicitly set search mode
47+
global_parameters.test_mode = True # Enable fast test mode
4748

4849
# Define a base configuration for the pipeline
4950
self.base_local_param_dict = {

ml_grid/util/validate_parameters.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Functions to validate model-specific hyperparameters before grid search."""
22

33
import logging
4-
from typing import Any, Dict
4+
from typing import Any, Dict, List, Union
55

66
from sklearn.neighbors import KNeighborsClassifier
77
from xgboost import XGBClassifier
@@ -11,23 +11,28 @@
1111

1212

1313
def validate_knn_parameters(
14-
parameters: Dict[str, Any], ml_grid_object: Any
15-
) -> Dict[str, Any]:
14+
parameters: Union[Dict[str, Any], List[Dict[str, Any]]], ml_grid_object: Any
15+
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
1616
"""Validates the `n_neighbors` parameter for KNN classifiers.
1717
1818
This function ensures that the values for `n_neighbors` do not exceed the
1919
number of samples in the training data. If a value is too large, it is
2020
capped at `n_samples - 1`.
2121
2222
Args:
23-
parameters (Dict[str, Any]): The dictionary of parameters to validate.
23+
parameters (Union[Dict[str, Any], List[Dict[str, Any]]]): The dictionary or list of dictionaries of parameters to validate.
2424
ml_grid_object (Any): The main pipeline object containing the training
2525
data (`X_train`).
2626
2727
Returns:
28-
Dict[str, Any]: The validated parameters dictionary.
28+
Union[Dict[str, Any], List[Dict[str, Any]]]: The validated parameters.
2929
"""
3030

31+
if isinstance(parameters, list):
32+
for i in range(len(parameters)):
33+
parameters[i] = validate_knn_parameters(parameters[i], ml_grid_object)
34+
return parameters
35+
3136
logger = logging.getLogger("ml_grid")
3237
# Get the number of samples in the training data
3338
logger.debug("Validating KNN parameters")
@@ -58,23 +63,31 @@ def validate_knn_parameters(
5863

5964

6065
def validate_XGB_parameters(
61-
parameters: Dict[str, Any], ml_grid_object: Any
62-
) -> Dict[str, Any]:
66+
parameters: Union[Dict[str, Any], List[Dict[str, Any]]], ml_grid_object: Any
67+
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
6368
"""Validates the `max_bin` parameter for XGBoost.
6469
6570
This function checks that the max_bin values are greater than or equal to 2,
6671
and if not, it sets them to 2.
6772
6873
Args:
69-
parameters (Dict[str, Any]): The dictionary of parameters to validate.
74+
parameters (Union[Dict[str, Any], List[Dict[str, Any]]]): The dictionary or list of dictionaries of parameters to validate.
7075
ml_grid_object (Any): The main pipeline object (currently unused).
7176
7277
Returns:
73-
Dict[str, Any]: The validated parameters dictionary.
78+
Union[Dict[str, Any], List[Dict[str, Any]]]: The validated parameters.
7479
"""
7580

81+
if isinstance(parameters, list):
82+
for i in range(len(parameters)):
83+
parameters[i] = validate_XGB_parameters(parameters[i], ml_grid_object)
84+
return parameters
85+
7686
max_bin_array = parameters.get("max_bin")
7787

88+
if max_bin_array is None:
89+
return parameters
90+
7891
# Iterate over each value in the max_bin array
7992
for i in range(len(max_bin_array)):
8093
# Check if the value is less than 2
@@ -89,17 +102,19 @@ def validate_XGB_parameters(
89102

90103

91104
def validate_parameters_helper(
92-
algorithm_implementation: Any, parameters: Dict[str, Any], ml_grid_object: Any
93-
) -> Dict[str, Any]:
105+
algorithm_implementation: Any,
106+
parameters: Union[Dict[str, Any], List[Dict[str, Any]]],
107+
ml_grid_object: Any,
108+
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
94109
"""Dispatches to the correct parameter validation function based on algorithm type.
95110
96111
Args:
97112
algorithm_implementation (Any): The scikit-learn estimator instance.
98-
parameters (Dict[str, Any]): The dictionary of parameters to validate.
113+
parameters (Union[Dict[str, Any], List[Dict[str, Any]]]): The parameters to validate.
99114
ml_grid_object (Any): The main pipeline object containing training data.
100115
101116
Returns:
102-
Dict[str, Any]: The validated parameters dictionary.
117+
Union[Dict[str, Any], List[Dict[str, Any]]]: The validated parameters.
103118
"""
104119

105120
if isinstance(algorithm_implementation, KNeighborsClassifier):

0 commit comments

Comments
 (0)