@@ -223,6 +223,60 @@ def objective(trial):
223223 return objective
224224
225225
def create_objective(
    config: dict, study_name: str, device_queue: queue.Queue
) -> callable:
    """
    Create the objective function for Optuna.

    Args:
        config (dict): Configuration dictionary.
        study_name (str): Name of the study.
        device_queue (queue.Queue): Queue of available devices.

    Returns:
        function: Objective function for Optuna.
    """

    def objective(trial):
        # Block until a (device, slot) pair is free; it is always returned
        # in the outermost finally, whatever happens in the trial.
        device, slot_id = device_queue.get()
        try:
            # Inner try: OOM is handled specially (capped loss instead of a
            # pruned/failed trial). If the OOM handler itself raises, the
            # outer except clauses still catch it.
            try:
                return training_run(trial, device, slot_id, config, study_name)
            except torch.cuda.OutOfMemoryError as oom_err:
                torch.cuda.empty_cache()
                detail = repr(oom_err).strip() or "CUDA Out of Memory (no details provided)."
                trial.set_user_attr("exception", detail)
                tqdm.write(f"[Trial {trial.number}] resulted in an OOM error.")
                if config.get("multi_objective", False):
                    # Multi-objective mode reports (loss, inference time).
                    return float(config.get("loss_cap", 20)), 10.0
                # Single-objective mode reports the capped loss only.
                return float(config.get("loss_cap", 20))
        except optuna.TrialPruned as pruned_err:
            # Record why the trial was pruned, then let Optuna handle it.
            trial.set_user_attr("exception", repr(pruned_err).strip())
            raise
        except Exception as unexpected_err:
            torch.cuda.empty_cache()
            detail = repr(unexpected_err).strip() or "Unknown error occurred."
            tqdm.write(
                f"Trial {trial.number} failed due to an unexpected error: {detail}"
            )
            trial.set_user_attr("exception", detail)
            # Convert arbitrary failures into pruned trials so the study continues.
            raise optuna.TrialPruned(f"Error in trial {trial.number}: {detail}")
        finally:
            device_queue.put((device, slot_id))

    return objective
278+
279+
226280def training_run (
227281 trial : optuna .Trial , device : str , slot_id : int , config : dict , study_name : str
228282) -> float | tuple [float , float ]:
@@ -244,7 +298,6 @@ def training_run(
244298
245299 download_data (config ["dataset" ]["name" ], verbose = False )
246300
247- # Load full data and parameters
248301 (
249302 (train_data , test_data , _ ),
250303 (train_params , test_params , _ ),
@@ -263,21 +316,29 @@ def training_run(
263316 )
264317
265318 subset_factor = config ["dataset" ].get ("subset_factor" , 1 )
266- # Get the appropriate subset of the training data
267- # We nevertheless use the full test data to measure performance.
268319 train_data = train_data [::subset_factor ]
269320 train_params = train_params [::subset_factor ] if train_params is not None else None
270321
271322 set_random_seeds (config ["seed" ], device = device )
272323 surr_name = config ["surrogate" ]["name" ]
273- suggested_params = make_optuna_params (trial , config ["optuna_params" ])
274- n_params = train_params .shape [1 ] if train_params is not None else 0
275324
325+ # Load base (best) config from disk as you already do
326+ model_config = get_model_config (surr_name , config )
327+
328+ # Decide search space
329+ if config .get ("fine" , False ):
330+ fine_space = config .get ("fine_space" )
331+ suggested_params = make_optuna_params (trial , fine_space )
332+ else :
333+ suggested_params = make_optuna_params (trial , config ["optuna_params" ])
334+
335+ n_params = train_params .shape [1 ] if train_params is not None else 0
276336 n_timesteps = train_data .shape [1 ]
277337 n_quantities = train_data .shape [2 ]
278338 surrogate_class = get_surrogate (surr_name )
279- model_config = get_model_config ( surr_name , config )
339+
280340 model_config .update (suggested_params )
341+
281342 model = surrogate_class (
282343 device = device ,
283344 n_quantities = n_quantities ,
@@ -312,30 +373,21 @@ def training_run(
312373 multi_objective = config ["multi_objective" ],
313374 )
314375
315- # criterion = torch.nn.MSELoss()
316376 preds , targets = model .predict (test_loader , leave_log = True )
317377 p99_dex = torch .quantile (
318378 (preds - targets ).abs ().flatten (), float (config ["target_percentile" ])
319379 ).item ()
320- # cap the loss to prevent exploding values
321380 p99_dex = min (p99_dex , config .get ("loss_cap" , 20 ))
322381
323- # Extract the study name without the timestamp/suffix part
324382 parts = study_name .split ("_" )
325383 sname = "_" .join (parts [:- 1 ]) if len (parts ) > 1 else study_name
326384
327385 savepath = os .path .join ("tuned" , sname , "models" )
328386 os .makedirs (savepath , exist_ok = True )
329387 model_name = f"{ surr_name .lower ()} _{ trial .number } "
330- model .save (
331- model_name = model_name ,
332- base_dir = "" ,
333- training_id = savepath ,
334- )
388+ model .save (model_name = model_name , base_dir = "" , training_id = savepath )
335389
336- # Check if we're running multi-objective optimisation
337390 if config ["multi_objective" ]:
338- # Measure inference time
339391 with _inference_time_lock :
340392 inference_times = measure_inference_time (model , test_loader )
341393 return p99_dex , np .mean (inference_times )
@@ -405,3 +457,54 @@ def is_bad(tr):
405457 f"\n [Study] Warmup complete. Runtime threshold set to { threshold :.1f} s "
406458 f"(mean = { mean_ :.1f} s, std = { std_ :.1f} s) over trials { used_trial_numbers } ."
407459 )
460+
461+
462+ def _bounds_around (
463+ v : float , factor : float = 10.0 , lo : float | None = None , hi : float | None = None
464+ ) -> tuple [float , float ]:
465+ low , high = float (v ) / factor , float (v ) * factor
466+ if lo is not None :
467+ low = max (low , lo )
468+ if hi is not None :
469+ high = min (high , hi )
470+ # avoid degenerate ranges
471+ if high <= low :
472+ eps = max (abs (v ) * 1e-3 , 1e-12 )
473+ low , high = float (v ) - eps , float (v ) + eps
474+ return low , high
475+
476+
477+ def build_fine_optuna_params (model_config : dict ) -> dict :
478+ keys = (
479+ "learning_rate" ,
480+ "beta" ,
481+ "poly_power" ,
482+ "eta_min" ,
483+ "regularization_factor" ,
484+ "momentum" ,
485+ )
486+ space : dict [str , dict ] = {}
487+ for k in keys :
488+ if k not in model_config :
489+ continue
490+ val = model_config [k ]
491+ if not isinstance (val , (int , float )) or val == 0 :
492+ continue
493+ lo , hi = _bounds_around (
494+ val ,
495+ factor = 10.0 ,
496+ lo = 1e-12 if k != "momentum" else 0.0 ,
497+ hi = 0.999 if k == "momentum" else None ,
498+ )
499+ space [k ] = {"type" : "float" , "low" : lo , "high" : hi , "log" : True }
500+ return space
501+
502+
def _is_valid_trial(t: optuna.trial.FrozenTrial) -> bool:
    """Return True for finished (complete or pruned) trials without a recorded exception."""
    if t.state not in (TrialState.COMPLETE, TrialState.PRUNED):
        return False
    return "exception" not in t.user_attrs
507+
508+
def _count_valid_trials(study: optuna.Study) -> int:
    """Count finished trials in *study* that did not record an exception."""
    total = 0
    for trial in study.get_trials(deepcopy=False):
        if _is_valid_trial(trial):
            total += 1
    return total
0 commit comments