@@ -199,6 +199,7 @@ def fit(
         epochs: int,
         position: int,
         description: str,
+        multi_objective: bool,
     ) -> None:
         """
         Perform the training of the model. Sets the train_loss and test_loss attributes.
@@ -209,16 +210,20 @@ def fit(
             epochs (int): The number of epochs to train the model for.
             position (int): The position of the progress bar.
             description (str): The description of the progress bar.
+            multi_objective (bool): Whether the training is multi-objective.
         """
         pass
 
-    def predict(self, data_loader: DataLoader) -> tuple[Tensor, Tensor]:
+    def predict(
+        self, data_loader: DataLoader, denormalize: bool = True
+    ) -> tuple[Tensor, Tensor]:
         """
         Evaluate the model on the given dataloader.
 
         Args:
             data_loader (DataLoader): The DataLoader object containing the data the
                 model is evaluated on.
+            denormalize (bool): Whether to denormalize the predictions and targets.
 
         Returns:
             tuple[Tensor, Tensor]: The predictions and targets.
@@ -256,8 +261,9 @@ def predict(self, data_loader: DataLoader) -> tuple[Tensor, Tensor]:
         predictions = predictions[:processed_samples, ...]
         targets = targets[:processed_samples, ...]
 
-        predictions = self.denormalize(predictions)
-        targets = self.denormalize(targets)
+        if denormalize:
+            predictions = self.denormalize(predictions)
+            targets = self.denormalize(targets)
 
         predictions = predictions.reshape(-1, self.n_timesteps, self.n_quantities)
         targets = targets.reshape(-1, self.n_timesteps, self.n_quantities)
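
A brief usage sketch of the new flag (`model` stands for any fitted concrete AbstractSurrogateModel subclass and `test_loader` for a prepared DataLoader; both names are hypothetical and not part of this diff):

    # Default behaviour: predictions and targets are denormalized back to physical space.
    preds_phys, targets_phys = model.predict(test_loader)

    # New flag: keep everything in the normalized space the model was trained in,
    # e.g. to compare errors on the same scale as the training loss.
    preds_norm, targets_norm = model.predict(test_loader, denormalize=False)
    mse_norm = ((preds_norm - targets_norm) ** 2).mean()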
@@ -419,7 +425,36 @@ def setup_progress_bar(self, epochs: int, position: int, description: str):
 
         return progress_bar
 
-    def denormalize(self, data: Tensor) -> Tensor:
+    def denormalize(self, data: Tensor, leave_log: bool = False) -> Tensor:
+        """
+        Denormalize the data.
+
+        Args:
+            data (Tensor): The data to denormalize.
+            leave_log (bool): If True, do not exponentiate the data even if log10_transform is True.
+
+        Returns:
+            Tensor: The denormalized data.
+        """
+        if self.normalisation is not None:
+            if self.normalisation["mode"] == "disabled":
+                ...
+            elif self.normalisation["mode"] == "minmax":
+                dmax = self.normalisation["max"]
+                dmin = self.normalisation["min"]
+                data = data.to("cpu")
+                data = (data + 1) * (dmax - dmin) / 2 + dmin
+            elif self.normalisation["mode"] == "standardize":
+                mean = self.normalisation["mean"]
+                std = self.normalisation["std"]
+                data = data * std + mean
+
+            if self.normalisation["log10_transform"] and not leave_log:
+                data = 10**data
+
+        return data
+
+    def denormalize_old(self, data: Tensor) -> Tensor:
         """
         Denormalize the data.
 
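
A minimal sketch of what the "minmax" branch inverts, assuming a normalisation dict with the keys the method reads ("mode", "min", "max", "log10_transform"); the forward scaling shown here is inferred from the inverse and is not itself part of this diff:

    import torch

    normalisation = {"mode": "minmax", "min": 0.0, "max": 4.0, "log10_transform": False}
    dmin, dmax = normalisation["min"], normalisation["max"]

    x = torch.tensor([0.0, 1.0, 4.0])
    # Forward min-max scaling to [-1, 1] (the transform the branch above undoes).
    x_norm = 2 * (x - dmin) / (dmax - dmin) - 1       # tensor([-1.0, -0.5, 1.0])
    # denormalize()'s minmax branch maps it back.
    x_back = (x_norm + 1) * (dmax - dmin) / 2 + dmin  # tensor([0.0, 1.0, 4.0])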
@@ -632,13 +667,13 @@ def validate(
         Only runs if (epoch % self.update_epochs) == 0.
         """
 
-        # 1) If it's not time to check yet, do nothing.
+        # If it's not time to check yet, do nothing.
         if epoch % self.update_epochs != 0:
             return
 
         index = epoch // self.update_epochs
 
-        # 2) Switch into inference/eval mode and compute losses
+        # Switch into inference/eval mode and compute losses
         with torch.inference_mode():
             self.eval()
             optimizer.eval() if hasattr(optimizer, "eval") else None
@@ -670,5 +705,121 @@ def validate(
         self.train()
         optimizer.train() if hasattr(optimizer, "train") else None
 
+    def setup_optimizer_and_scheduler(
+        self,
+        epochs: int,
+    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
+        """
+        Set up the optimizer and scheduler based on self.config.optimizer and self.config.scheduler.
+        Supports the "adamw" and "sgd" optimizers and the "schedulefree", "cosine", and "poly" schedulers.
+        Patches standard optimizers so that .train() and .eval() exist as no-ops.
+        Patches ScheduleFree optimizers to have a no-op scheduler.step().
+        For ScheduleFree optimizers, uses lr warmup for the first 1% of epochs.
+        For the poly scheduler, uses a power decay based on self.config.poly_power.
+        For the cosine scheduler, uses a minimum learning rate defined by self.config.eta_min.
+
+        Args:
+            epochs (int): The number of epochs the training will run for.
+
+        Returns:
+            tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
+                The optimizer and scheduler instances.
+        Raises:
+            ValueError: If an unknown optimizer or scheduler is specified in the config.
+        """
+        scheduler_name = self.config.scheduler.lower()
+        optimizer_name = self.config.optimizer.lower()
+
+        class DummyScheduler:
+            def step(self, *args, **kwargs):
+                pass
+
+            def state_dict(self):
+                return {}
+
+            def load_state_dict(self, state_dict):
+                pass
+
+        # Create the optimizer
+        if optimizer_name == "adamw":
+            if scheduler_name == "schedulefree":
+                from schedulefree import AdamWScheduleFree
+
+                optimizer = AdamWScheduleFree(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                    warmup_steps=max(1, epochs // 100),
+                )
+            else:
+                from torch.optim import AdamW
+
+                optimizer = AdamW(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                )
+        elif optimizer_name == "sgd":
+            momentum = self.config.momentum
+            if scheduler_name == "schedulefree":
+                from schedulefree import SGDScheduleFree
+
+                optimizer = SGDScheduleFree(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                    momentum=momentum,
+                    warmup_steps=max(1, epochs // 100),
+                )
+            else:
+                from torch.optim import SGD
+
+                optimizer = SGD(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                    momentum=momentum,
+                )
+        else:
+            raise ValueError(f"Unknown optimizer '{self.config.optimizer}'")
+
+        # Patch the optimizer to have no-op train() and eval(), if not present
+        if not hasattr(optimizer, "train"):
+
+            def _opt_train():
+                pass
+
+            optimizer.train = _opt_train
+        if not hasattr(optimizer, "eval"):
+
+            def _opt_eval():
+                pass
+
+            optimizer.eval = _opt_eval
+
+        # Create the scheduler
+        if scheduler_name == "schedulefree":
+            scheduler = DummyScheduler()
+        elif scheduler_name == "cosine":
+            from torch.optim.lr_scheduler import CosineAnnealingLR
+
+            eta_min = self.config.eta_min
+            scheduler = CosineAnnealingLR(
+                optimizer,
+                T_max=epochs,
+                eta_min=eta_min,
+            )
+        elif scheduler_name == "poly":
+            from torch.optim.lr_scheduler import LambdaLR
+
+            power = self.config.poly_power
+            scheduler = LambdaLR(
+                optimizer, lr_lambda=lambda epoch: (1 - epoch / float(epochs)) ** power
+            )
+        else:
+            raise ValueError(f"Unknown scheduler '{self.config.scheduler}'")
+
+        return optimizer, scheduler
+
 
 SurrogateModel = TypeVar("SurrogateModel", bound=AbstractSurrogateModel)
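
A usage sketch of setup_optimizer_and_scheduler, assuming a hypothetical config object that carries the fields the method reads (the real config class is not shown in this diff, and `model` stands for a concrete AbstractSurrogateModel subclass instance):

    from dataclasses import dataclass

    @dataclass
    class TrainConfig:  # hypothetical stand-in for the project's config
        optimizer: str = "adamw"
        scheduler: str = "cosine"
        learning_rate: float = 3e-4
        regularization_factor: float = 1e-2
        eta_min: float = 1e-6
        poly_power: float = 0.9
        momentum: float = 0.9

    model.config = TrainConfig()
    optimizer, scheduler = model.setup_optimizer_and_scheduler(epochs=200)

    for epoch in range(200):
        optimizer.train()   # no-op patch for AdamW/SGD, real mode switch for ScheduleFree
        # ... forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad() ...
        scheduler.step()    # DummyScheduler (no-op) when scheduler == "schedulefree"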