Skip to content

Commit 3511659

Browse files
author
SamoraHunter
committed
Fixed global parameter setting for n_iter
1 parent 56d4537 commit 3511659

3 files changed

Lines changed: 41 additions & 2 deletions

File tree

config_hyperopt.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ global_params:
99
h2o_show_progress: false
1010
# Number of iterations for RandomizedSearchCV and BayesSearchCV
1111
n_iter: 2
12+
max_param_space_iter_value: 10
1213

1314
# Experiment settings for the hyperopt run
1415
experiment:

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ def __init__(
171171
self.global_params.max_param_space_iter_value
172172
) # hard limit on param space exploration
173173

174+
# Allow local override for max_param_space_iter_value
175+
if self.ml_grid_object_iter.local_param_dict.get("max_param_space_iter_value") is not None:
176+
max_param_space_iter_value = self.ml_grid_object_iter.local_param_dict.get("max_param_space_iter_value")
177+
174178
if "svc" in method_name.lower():
175179
self.X_train = scale_data(self.X_train)
176180
self.X_test = scale_data(self.X_test)
@@ -281,8 +285,31 @@ def __init__(
281285
parameter_space = new_parameter_space
282286

283287
# Use the new n_iter parameter from the config
284-
# Default to 50 if not present, preventing AttributeError
285-
n_iter_v = getattr(self.global_params, "n_iter", 2)
288+
# Default to 2 if not present, preventing AttributeError
289+
try:
290+
n_iter_v = getattr(self.global_params, "n_iter", 2)
291+
if n_iter_v is None:
292+
n_iter_v = 2
293+
n_iter_v = int(n_iter_v)
294+
except (ValueError, TypeError):
295+
self.logger.warning("Invalid or missing n_iter in global_params. Defaulting to 2.")
296+
n_iter_v = 2
297+
298+
# Allow local override from run_params/local_param_dict
299+
local_n_iter = self.ml_grid_object_iter.local_param_dict.get("n_iter")
300+
if local_n_iter is not None:
301+
try:
302+
n_iter_v = int(local_n_iter)
303+
self.logger.info(f"Overriding global n_iter with local value: {n_iter_v}")
304+
except (ValueError, TypeError):
305+
self.logger.warning(f"Invalid local n_iter value: {local_n_iter}. Ignoring override.")
306+
307+
if max_param_space_iter_value is not None:
308+
if n_iter_v > max_param_space_iter_value:
309+
self.logger.info(
310+
f"Capping n_iter ({n_iter_v}) to max_param_space_iter_value ({max_param_space_iter_value})"
311+
)
312+
n_iter_v = max_param_space_iter_value
286313

287314
# For GridSearchCV, n_iter is not used, but we calculate the grid size for logging.
288315
if not self.global_params.bayessearch and not random_grid_search:

ml_grid/pipeline/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import logging
22
import traceback
3+
import glob
4+
import os
5+
import yaml
36
from typing import Any, Dict, List, Tuple
47

58
import numpy as np
@@ -79,6 +82,10 @@ def __init__(self, local_param_dict: Dict[str, Any], **kwargs):
7982
"""
8083
self.global_params = global_parameters
8184

85+
# Update global parameters if provided in kwargs
86+
if "global_params" in kwargs and isinstance(kwargs["global_params"], dict):
87+
self.global_params.update_parameters(**kwargs["global_params"])
88+
8289
self.logger = logging.getLogger("ml_grid")
8390

8491
self.verbose = self.global_params.verbose
@@ -99,6 +106,10 @@ def __init__(self, local_param_dict: Dict[str, Any], **kwargs):
99106
}
100107
self.ml_grid_object = pipe(**pipe_kwargs)
101108

109+
# Propagate n_iter from global_params to local_param_dict
110+
# This ensures the value persists across process boundaries (pickling) where the singleton might be reset
111+
self.ml_grid_object.local_param_dict["n_iter"] = self.global_params.n_iter
112+
102113
self.error_raise = self.global_params.error_raise
103114

104115
self.sub_sample_param_space_pct = self.global_params.sub_sample_param_space_pct

0 commit comments

Comments
 (0)