Skip to content

Commit 942f9de

Browse files
committed
Fix TypeError in pipeline run initialization by detecting skopt spaces
- Added detection logic for `skopt` space objects (`Real`, `Integer`, `Categorical`) in `ml_grid/pipeline/main.py`. - Updated parameter grid size calculation to use `calculate_combinations` instead of `ParameterGrid` when Bayesian parameter spaces are detected, even if `bayessearch` is globally False. - Prevents `TypeError` when `ParameterGrid` encounters non-iterable `skopt` objects during pipeline initialization.
1 parent 18656a9 commit 942f9de

1 file changed

Lines changed: 17 additions & 6 deletions

File tree

ml_grid/pipeline/main.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
from sklearn.model_selection import ParameterGrid
10+
from skopt.space import Categorical, Integer, Real
1011

1112
from ml_grid.pipeline import grid_search_cross_validate
1213
from ml_grid.pipeline.data import pipe
@@ -181,17 +182,27 @@ def __init__(self, local_param_dict: Dict[str, Any], **kwargs):
181182

182183
for elem in self.model_class_list:
183184

184-
if not self.global_params.bayessearch:
185-
# ParameterGrid can now be called directly, as the model class
186-
# provides a grid-search-compatible parameter space.
185+
# Check if the parameter space contains skopt objects
186+
is_bayes_space = False
187+
space_to_check = elem.parameter_space
188+
if isinstance(space_to_check, list) and space_to_check:
189+
space_to_check = space_to_check[0] # Check the first dict
190+
191+
if isinstance(space_to_check, dict):
192+
if any(
193+
isinstance(v, (Real, Integer, Categorical))
194+
for v in space_to_check.values()
195+
):
196+
is_bayes_space = True
197+
198+
if not self.global_params.bayessearch and not is_bayes_space:
199+
# This is a true grid search space
187200
pg = ParameterGrid(elem.parameter_space)
188201
pg = len(pg)
189202
else:
190-
203+
# This handles both explicit bayessearch=True and mismatched grid spaces
191204
pg = calculate_combinations(elem.parameter_space, steps=10)
192205

193-
# pg = ParameterGrid(elem.parameter_space)
194-
195206
self.pg_list.append(pg)
196207

197208
if self.verbose >= 1:

0 commit comments

Comments
 (0)