Skip to content

Commit 0711809

Browse files
author
SamoraHunter
committed
formatting
1 parent c8a81d8 commit 0711809

3 files changed

Lines changed: 26 additions & 11 deletions

File tree

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
H2OStackedEnsembleClassifier,
6565
)
6666

67+
6768
class grid_search_crossvalidate:
6869

6970
def __init__(
@@ -441,10 +442,10 @@ def __init__(
441442
"train_score": np.array([0.5]),
442443
"test_recall": np.array([0.5]),
443444
}
444-
445+
445446
failed = False
446447
scores = None
447-
448+
448449
# Initialize start_time early
449450
start_time = time.time()
450451

ml_grid/pipeline/main.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def time_limit(seconds):
2424
try:
2525
seconds_int = int(seconds)
2626
except (ValueError, TypeError):
27-
logging.getLogger("ml_grid").warning(f"Invalid timeout value: {seconds}. Timeout disabled.")
27+
logging.getLogger("ml_grid").warning(
28+
f"Invalid timeout value: {seconds}. Timeout disabled."
29+
)
2830
yield
2931
return
3032

@@ -33,9 +35,12 @@ def time_limit(seconds):
3335
return
3436

3537
if not hasattr(signal, "SIGALRM"):
36-
logging.getLogger("ml_grid").warning("Timeout not supported on this platform (SIGALRM missing).")
38+
logging.getLogger("ml_grid").warning(
39+
"Timeout not supported on this platform (SIGALRM missing)."
40+
)
3741
yield
3842
return
43+
3944
def signal_handler(signum, frame):
4045
raise TimeoutError(f"Timeout of {seconds}s reached")
4146

@@ -65,6 +70,7 @@ def signal_handler(signum, frame):
6570
remaining_outer = max(1, int(previous_remaining - elapsed))
6671
signal.alarm(remaining_outer)
6772

73+
6874
class run:
6975
"""Orchestrates the hyperparameter search for a list of models."""
7076

@@ -294,16 +300,18 @@ def execute_single_model(self, args: Tuple) -> float:
294300
"""
295301
try:
296302
self.logger.info(f"Starting grid search for {args[2]}...")
297-
303+
298304
# Retrieve timeout from local_param_dict via ml_grid_object (args[3])
299305
timeout = args[3].local_param_dict.get("model_eval_time_limit")
300306
if timeout is None:
301307
timeout = args[3].global_params.model_eval_time_limit
302-
308+
303309
with time_limit(timeout):
304-
gscv_instance = grid_search_cross_validate.grid_search_crossvalidate(*args)
310+
gscv_instance = grid_search_cross_validate.grid_search_crossvalidate(
311+
*args
312+
)
305313
score = gscv_instance.grid_search_cross_validate_score_result
306-
314+
307315
self.logger.info(f"Score for {args[2]}: {score:.4f}")
308316
return score
309317

@@ -364,7 +372,7 @@ def multi_run_wrapper(args: Tuple) -> Any:
364372
self.logger.info(
365373
f"Starting grid search for {self.arg_list[k][2]}..."
366374
)
367-
375+
368376
timeout = self.local_param_dict.get("model_eval_time_limit")
369377
if timeout is None:
370378
timeout = self.global_params.model_eval_time_limit
@@ -383,7 +391,9 @@ def multi_run_wrapper(args: Tuple) -> Any:
383391
self.logger.info(f"Current highest score: {self.highest_score:.4f}")
384392

385393
except TimeoutError as e:
386-
self.logger.warning(f"Timeout occurred for {self.arg_list[k][2]}: {e}")
394+
self.logger.warning(
395+
f"Timeout occurred for {self.arg_list[k][2]}: {e}"
396+
)
387397
self.model_error_list.append(
388398
[self.arg_list[k][0], e, traceback.format_exc()]
389399
)

ml_grid/util/project_score_save.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,11 @@ def update_score_log(
190190
# --- OPTIMIZATION: Pre-process targets for faster metric calculation ---
191191
# Convert to numpy arrays to avoid pandas overhead in sklearn metrics
192192
y_test_np = y_test.values if hasattr(y_test, "values") else y_test
193-
best_pred_np = best_pred_orig.values if hasattr(best_pred_orig, "values") else best_pred_orig
193+
best_pred_np = (
194+
best_pred_orig.values
195+
if hasattr(best_pred_orig, "values")
196+
else best_pred_orig
197+
)
194198

195199
# Ensure 1D arrays to prevent shape mismatch errors
196200
y_test_np = np.ravel(y_test_np)

0 commit comments

Comments
 (0)