Skip to content

Commit 55500ca

Browse files
committed
Additional documentation for API reference
1 parent 5aa6414 commit 55500ca

2 files changed

Lines changed: 77 additions & 3 deletions

File tree

ml_grid/pipeline/hyperparameter_search.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,33 @@
2020
class HyperparameterSearch:
2121
"""Orchestrates hyperparameter search using GridSearchCV, RandomizedSearchCV, or BayesSearchCV."""
2222

23+
algorithm: BaseEstimator
24+
"""The scikit-learn compatible estimator instance."""
25+
26+
parameter_space: Union[Dict, List[Dict]]
27+
"""The hyperparameter search space."""
28+
29+
method_name: str
30+
"""The name of the algorithm."""
31+
32+
global_params: global_parameters
33+
"""A reference to the global parameters singleton instance."""
34+
35+
sub_sample_pct: int
36+
"""
37+
Percentage of the parameter space to sample for randomized search.
38+
Defaults to 100.
39+
"""
40+
41+
max_iter: int
42+
"""
43+
The maximum number of iterations for randomized or Bayesian search.
44+
Defaults to 100.
45+
"""
46+
47+
ml_grid_object: Any
48+
"""The main pipeline object containing data and other parameters."""
49+
2350
def __init__(
2451
self,
2552
algorithm: BaseEstimator,

ml_grid/pipeline/main.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,53 @@
1212
class run:
1313
"""Orchestrates the hyperparameter search for a list of models."""
1414

15+
global_params: global_parameters
16+
"""A reference to the global parameters singleton instance."""
17+
18+
verbose: int
19+
"""The verbosity level for logging, inherited from global parameters."""
20+
21+
error_raise: bool
22+
"""A flag to control error handling. If True, exceptions will be raised."""
23+
24+
ml_grid_object: pipe
25+
"""The main data pipeline object, containing data and model configurations."""
26+
27+
sub_sample_param_space_pct: float
28+
"""The percentage of the parameter space to sample in a randomized search."""
29+
30+
parameter_space_size: str
31+
"""The size of the parameter space for base learners (e.g., 'medium', 'xsmall')."""
32+
33+
model_class_list: List[Any]
34+
"""A list of instantiated model class objects to be evaluated in this run."""
35+
36+
pg_list: List[int]
37+
"""A list containing the calculated size of the parameter grid for each model."""
38+
39+
mean_parameter_space_val: float
40+
"""The mean size of the parameter spaces across all models in the run."""
41+
42+
sub_sample_parameter_val: int
43+
"""The calculated number of iterations for randomized search, based on `sub_sample_param_space_pct`."""
44+
45+
arg_list: List[Tuple]
46+
"""A list of argument tuples, one for each model, to be passed to the grid search function."""
47+
48+
multiprocess: bool
49+
"""A flag to enable or disable multiprocessing for running grid searches in parallel."""
50+
51+
local_param_dict: Dict[str, Any]
52+
"""A dictionary of parameters for the current experimental run."""
53+
54+
model_error_list: List[List[Any]]
55+
"""A list to store details of any errors encountered during model training."""
56+
57+
highest_score: float
58+
"""The highest score achieved across all successful model runs in the execute step."""
59+
60+
61+
1562
def __init__(self, ml_grid_object: pipe, local_param_dict: Dict[str, Any]):
1663
"""Initializes the run class.
1764
@@ -132,7 +179,7 @@ def execute(self) -> Tuple[List[List[Any]], float]:
132179
"""
133180

134181
self.model_error_list = []
135-
182+
self.highest_score = 0
136183
highest_score = 0 # for optimisation
137184

138185
if self.multiprocess:
@@ -158,7 +205,7 @@ def multi_run_wrapper(args: Tuple) -> Any:
158205
# algorithm_implementation = LogisticRegression_class(parameter_space_size=self.parameter_space_size).algorithm_implementation, parameter_space = self.arg_list[k][1], method_name=self.arg_list[k][2], X = self.arg_list[k][3], y=self.arg_list[k][4]
159206
).grid_search_cross_validate_score_result
160207

161-
highest_score = max(highest_score, res)
208+
self.highest_score = max(self.highest_score, res)
162209
print(f"highest score: {highest_score}")
163210

164211
except CatBoostError as e:
@@ -191,4 +238,4 @@ def multi_run_wrapper(args: Tuple) -> Any:
191238
# return highest score from run for additional optimisation:
192239

193240

194-
return self.model_error_list, highest_score
241+
return self.model_error_list, self.highest_score

0 commit comments

Comments (0)