11import dataclasses
22import os
3+ import time
34from abc import ABC , abstractmethod
45from datetime import datetime
56from typing import Any , TypeVar
67
78import numpy as np
9+ import optuna
810import torch
911import yaml
1012from torch import Tensor , nn
@@ -71,7 +73,7 @@ class AbstractSurrogateModel(ABC, nn.Module):
7173 model_name: str,
7274 subfolder: str,
7375 training_id: str,
74- data_params : dict,
76+ data_info : dict,
7577 ) -> None:
7678 Saves the model to disk.
7779
@@ -99,6 +101,7 @@ def __init__(
99101 device : str | None = None ,
100102 n_quantities : int = 29 ,
101103 n_timesteps : int = 100 ,
104+ n_parameters : int = 0 ,
102105 config : dict | None = None ,
103106 ):
104107 super ().__init__ ()
@@ -109,6 +112,7 @@ def __init__(
109112 self .device = device
110113 self .n_quantities = n_quantities
111114 self .n_timesteps = n_timesteps
115+ self .n_parameters = n_parameters
112116 self .L1 = nn .L1Loss ()
113117 self .config = config if config is not None else {}
114118 self .train_duration = None
@@ -265,7 +269,7 @@ def save(
265269 model_name (str): The name of the model.
266270 subfolder (str): The subfolder to save the model in.
267271 training_id (str): The training identifier.
268- data_params (dict): The data parameters.
272+ data_info (dict): The data parameters.
269273 """
270274
271275 # Make the model directory
@@ -329,7 +333,7 @@ def save(
329333
330334 save_attributes = {
331335 k : v
332- for k , v in self .__dict__ .items ()
336+ for k , v in self .__dict__ .copy (). items ()
333337 if k != "state_dict" and not k .startswith ("_" )
334338 }
335339 model_dict = {"state_dict" : self .state_dict (), "attributes" : save_attributes }
@@ -392,6 +396,7 @@ def setup_progress_bar(self, epochs: int, position: int, description: str):
392396 Returns:
393397 tqdm: The progress bar.
394398 """
399+
395400 bar_format = "{l_bar}{bar}| {n_fmt:>5}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt} {postfix}]"
396401 progress_bar = tqdm (
397402 range (epochs ),
@@ -401,6 +406,9 @@ def setup_progress_bar(self, epochs: int, position: int, description: str):
401406 bar_format = bar_format ,
402407 )
403408
409+ # Only used for time_pruning in multi objective optimisation
410+ self ._trial_start_time = time .time ()
411+
404412 return progress_bar
405413
406414 def denormalize (self , data : Tensor ) -> Tensor :
@@ -430,5 +438,64 @@ def denormalize(self, data: Tensor) -> Tensor:
430438
431439 return data
432440
441+ def time_pruning (self , current_epoch : int , total_epochs : int ) -> None :
442+ """
443+ Determine whether a trial should be pruned based on projected runtime,
444+ but only after a warmup period (10% of the total epochs).
445+
446+ Warmup: Do not prune if current_epoch is less than warmup_epochs.
447+ After warmup, compute the average epoch time, extrapolate the total runtime,
448+ and retrieve the threshold (runtime_threshold) from the study's user attributes.
449+ If the projected runtime exceeds the threshold, raise an optuna.TrialPruned exception.
450+
451+ Args:
452+ current_epoch (int): The current epoch count.
453+ total_epochs (int): The planned total number of epochs.
454+
455+ Raises:
456+ optuna.TrialPruned: If the projected runtime exceeds the threshold.
457+ """
458+ # Define warmup period based on 10% of total epochs.
459+ warmup_epochs = max (50 , int (total_epochs * 0.02 ))
460+ if current_epoch < warmup_epochs :
461+ # Do not attempt to prune before the warmup period is complete.
462+ # print(
463+ # f"[time_pruning] Warmup period: {current_epoch}/{warmup_epochs} epochs completed. Skipping pruning check."
464+ # )
465+ return
466+
467+ elapsed = time .time () - self ._trial_start_time
468+ completed_epochs = max (current_epoch , 1 )
469+ average_epoch_time = elapsed / completed_epochs
470+ projected_total_time = average_epoch_time * total_epochs
471+
472+ # Retrieve threshold from study's user attributes.
473+ if self .optuna_trial is not None and hasattr (self .optuna_trial , "study" ):
474+ threshold = self .optuna_trial .study .user_attrs .get (
475+ "runtime_threshold" , None
476+ )
477+ else :
478+ threshold = None
479+
480+ # print(
481+ # f"[time_pruning] Epoch: {current_epoch}/{total_epochs} | "
482+ # f"Elapsed: {elapsed:.1f}s | Avg per epoch: {average_epoch_time:.1f}s | "
483+ # f"Projected total: {projected_total_time:.1f}s | Threshold: {threshold:.1f}s"
484+ # )
485+
486+ if threshold is not None :
487+ if projected_total_time > threshold :
488+ if self .optuna_trial is not None :
489+ tqdm .write (
490+ f"[time_pruning] Projected total time { projected_total_time :.1f} s exceeds threshold { threshold :.1f} s. Pruning trial."
491+ )
492+ self .optuna_trial .set_user_attr (
493+ "prune_reason" ,
494+ f"Projected runtime { projected_total_time :.1f} s exceeds threshold { threshold :.1f} s" ,
495+ )
496+ raise optuna .TrialPruned (
497+ f"Projected total time { projected_total_time :.1f} s exceeds threshold { threshold :.1f} s"
498+ )
499+
433500
434501SurrogateModel = TypeVar ("SurrogateModel" , bound = AbstractSurrogateModel )
0 commit comments