Skip to content

Commit 07d4a77

Browse files
committed
feat(validation): Add generic parameter filtering and optional XGBoost support
This commit significantly refactors the parameter validation logic to be more robust and flexible. Key changes:
- **Generic Parameter Filtering**: `validate_parameters_helper` now includes a generic fallback mechanism. It inspects the `algorithm_implementation` using `get_params()` and removes any keys from the parameter space that are not valid for that specific model. This prevents `TypeError` exceptions during hyperparameter search when a parameter space contains keys not recognized by the estimator.
- **Optional XGBoost**: The import of `XGBClassifier` has been moved inside a `try...except` block. This makes `xgboost` an optional dependency, allowing the library to run without it installed.
- **Improved Type Checking**: Added `isinstance(..., list)` checks in `validate_knn_parameters` and `validate_XGB_parameters` to prevent errors when parameter values are not lists as expected.
- **Cleanup**: Minor whitespace removal in `tests/conftest.py`.
1 parent 456d477 commit 07d4a77

2 files changed

Lines changed: 46 additions & 26 deletions

File tree

ml_grid/util/validate_parameters.py

Lines changed: 46 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
from typing import Any, Dict, List, Union
55

66
from sklearn.neighbors import KNeighborsClassifier
7-
from xgboost import XGBClassifier
87

98
# from ml_grid.model_classes.knn_gpu_classifier_class import KNNGpuWrapperClass
109
# from ml_grid.model_classes.knn_wrapper_class import KNNWrapper
@@ -49,15 +48,15 @@ def validate_knn_parameters(
4948
logger.debug(f" Initial n_neighbors: {n_neighbors}")
5049

5150
# Check if any n_neighbors values are too large
52-
if n_neighbors is not None:
51+
if n_neighbors is not None and isinstance(n_neighbors, list):
5352
for i in range(len(n_neighbors)):
5453
if n_neighbors[i] > max_neighbors:
5554
logger.debug(
5655
f" Capping n_neighbors[{i}] from {n_neighbors[i]} to {max_neighbors}"
5756
)
5857
n_neighbors[i] = max_neighbors
5958

60-
parameters["n_neighbors"] = n_neighbors
59+
parameters["n_neighbors"] = n_neighbors
6160
# Return the validated parameters
6261
return parameters
6362

@@ -85,7 +84,7 @@ def validate_XGB_parameters(
8584

8685
max_bin_array = parameters.get("max_bin")
8786

88-
if max_bin_array is None:
87+
if max_bin_array is None or not isinstance(max_bin_array, list):
8988
return parameters
9089

9190
# Iterate over each value in the max_bin array
@@ -106,7 +105,13 @@ def validate_parameters_helper(
106105
parameters: Union[Dict[str, Any], List[Dict[str, Any]]],
107106
ml_grid_object: Any,
108107
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
109-
"""Dispatches to the correct parameter validation function based on algorithm type.
108+
"""Dispatches to model-specific validation or performs generic filtering.
109+
110+
This function first checks for model-specific validation routines (e.g., for
111+
KNN, XGBoost). If no specific routine is found, it performs a generic
112+
validation that removes any parameters from the search space that are not
113+
valid for the given algorithm instance. This prevents `TypeError` exceptions
114+
from scikit-learn's search classes.
110115
111116
Args:
112117
algorithm_implementation (Any): The scikit-learn estimator instance.
@@ -116,29 +121,45 @@ def validate_parameters_helper(
116121
Returns:
117122
Union[Dict[str, Any], List[Dict[str, Any]]]: The validated parameters.
118123
"""
124+
logger = logging.getLogger("ml_grid")
119125

126+
# --- Model-specific validation ---
120127
if isinstance(algorithm_implementation, KNeighborsClassifier):
121-
122-
parameters = validate_knn_parameters(parameters, ml_grid_object)
123-
128+
return validate_knn_parameters(parameters, ml_grid_object)
129+
130+
try:
131+
from xgboost import XGBClassifier
132+
133+
if isinstance(algorithm_implementation, XGBClassifier):
134+
return validate_XGB_parameters(parameters, ml_grid_object)
135+
except ImportError:
136+
logger.debug("XGBoost not installed, skipping XGBoost-specific validation.")
137+
pass
138+
139+
# --- Generic fallback: Filter invalid parameters ---
140+
try:
141+
valid_params = algorithm_implementation.get_params().keys()
142+
except Exception:
143+
logger.warning(
144+
f"Could not get params for {algorithm_implementation.__class__.__name__}. Skipping generic validation."
145+
)
124146
return parameters
125147

126-
# elif type(algorithm_implementation) == KNNWrapper:
127-
128-
# parameters = validate_knn_parameters(parameters, ml_grid_object)
129-
130-
# return parameters
131-
132-
# elif isinstance(algorithm_implementation, KNNGpuWrapperClass):
148+
def _filter_dict(param_dict: Dict) -> Dict:
149+
"""Filters a single parameter dictionary."""
150+
if not isinstance(param_dict, dict):
151+
return param_dict
152+
validated_dict = {k: v for k, v in param_dict.items() if k in valid_params}
153+
removed_keys = set(param_dict.keys()) - set(validated_dict.keys())
154+
if removed_keys:
155+
logger.debug(
156+
f"Removed invalid keys for {algorithm_implementation.__class__.__name__}: {removed_keys}"
157+
)
158+
return validated_dict
133159

134-
# parameters = validate_knn_parameters(parameters, ml_grid_object)
135-
136-
# return parameters
137-
138-
elif isinstance(algorithm_implementation, XGBClassifier):
139-
parameters = validate_XGB_parameters(parameters, ml_grid_object)
140-
141-
return parameters
160+
if isinstance(parameters, list):
161+
return [_filter_dict(p) for p in parameters]
162+
elif isinstance(parameters, dict):
163+
return _filter_dict(parameters)
142164

143-
else:
144-
return parameters
165+
return parameters

tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import h2o
88
import pytest
99

10-
1110
# --- Tame TensorFlow ---
1211
# Set log level to suppress info/warnings before importing
1312
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

0 commit comments

Comments (0)