@@ -24,7 +24,9 @@ def time_limit(seconds):
2424 try :
2525 seconds_int = int (seconds )
2626 except (ValueError , TypeError ):
27- logging .getLogger ("ml_grid" ).warning (f"Invalid timeout value: { seconds } . Timeout disabled." )
27+ logging .getLogger ("ml_grid" ).warning (
28+ f"Invalid timeout value: { seconds } . Timeout disabled."
29+ )
2830 yield
2931 return
3032
@@ -33,9 +35,12 @@ def time_limit(seconds):
3335 return
3436
3537 if not hasattr (signal , "SIGALRM" ):
36- logging .getLogger ("ml_grid" ).warning ("Timeout not supported on this platform (SIGALRM missing)." )
38+ logging .getLogger ("ml_grid" ).warning (
39+ "Timeout not supported on this platform (SIGALRM missing)."
40+ )
3741 yield
3842 return
43+
3944 def signal_handler (signum , frame ):
4045 raise TimeoutError (f"Timeout of { seconds } s reached" )
4146
@@ -65,6 +70,7 @@ def signal_handler(signum, frame):
6570 remaining_outer = max (1 , int (previous_remaining - elapsed ))
6671 signal .alarm (remaining_outer )
6772
73+
6874class run :
6975 """Orchestrates the hyperparameter search for a list of models."""
7076
@@ -294,16 +300,18 @@ def execute_single_model(self, args: Tuple) -> float:
294300 """
295301 try :
296302 self .logger .info (f"Starting grid search for { args [2 ]} ..." )
297-
303+
298304 # Retrieve timeout from local_param_dict via ml_grid_object (args[3])
299305 timeout = args [3 ].local_param_dict .get ("model_eval_time_limit" )
300306 if timeout is None :
301307 timeout = args [3 ].global_params .model_eval_time_limit
302-
308+
303309 with time_limit (timeout ):
304- gscv_instance = grid_search_cross_validate .grid_search_crossvalidate (* args )
310+ gscv_instance = grid_search_cross_validate .grid_search_crossvalidate (
311+ * args
312+ )
305313 score = gscv_instance .grid_search_cross_validate_score_result
306-
314+
307315 self .logger .info (f"Score for { args [2 ]} : { score :.4f} " )
308316 return score
309317
@@ -364,7 +372,7 @@ def multi_run_wrapper(args: Tuple) -> Any:
364372 self .logger .info (
365373 f"Starting grid search for { self .arg_list [k ][2 ]} ..."
366374 )
367-
375+
368376 timeout = self .local_param_dict .get ("model_eval_time_limit" )
369377 if timeout is None :
370378 timeout = self .global_params .model_eval_time_limit
@@ -383,7 +391,9 @@ def multi_run_wrapper(args: Tuple) -> Any:
383391 self .logger .info (f"Current highest score: { self .highest_score :.4f} " )
384392
385393 except TimeoutError as e :
386- self .logger .warning (f"Timeout occurred for { self .arg_list [k ][2 ]} : { e } " )
394+ self .logger .warning (
395+ f"Timeout occurred for { self .arg_list [k ][2 ]} : { e } "
396+ )
387397 self .model_error_list .append (
388398 [self .arg_list [k ][0 ], e , traceback .format_exc ()]
389399 )
0 commit comments