
Commit 9af1345

Merge pull request #43 from robin-janssen/improve-optimizer-and-scheduler
Improve optimizer and scheduler
2 parents 5e3b735 + a44b206 commit 9af1345

24 files changed

Lines changed: 596 additions & 415 deletions

codes/benchmark/bench_fcts.py

Lines changed: 2 additions & 1 deletion
@@ -86,11 +86,12 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
         labels,
     ) = check_and_load_data(
         conf["dataset"]["name"],
-        verbose=False,
+        verbose=conf.get("verbose", False),
         log=conf["dataset"]["log10_transform"],
         log_params=conf.get("log10_transform_params", False),
         normalisation_mode=conf["dataset"]["normalise"],
         tolerance=conf["dataset"]["tolerance"],
+        per_species=conf["dataset"].get("normalise_per_species", False),
     )

     model_config = get_model_config(surr_name, conf)
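For reference, both new keyword arguments are driven by config entries rather than hard-coded values. A minimal sketch of the relevant part of such a config as a Python dict follows; the "verbose", "log10_transform_params", and "dataset" keys are taken from the diff, while all concrete values are illustrative assumptions:

conf = {
    "verbose": True,  # now forwarded to check_and_load_data
    "log10_transform_params": False,
    "dataset": {
        "name": "example_dataset",       # placeholder name
        "log10_transform": True,
        "normalise": "minmax",
        "tolerance": 1e-20,              # assumed value
        "normalise_per_species": True,   # new optional key; absent means False
    },
}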
Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
-from .surrogates import AbstractSurrogateModel, SurrogateModel
+from .abstract_config import AbstractSurrogateBaseConfig
+from .abstract_surrogate import AbstractSurrogateModel, SurrogateModel

-__all__ = ["AbstractSurrogateModel", "SurrogateModel"]
+__all__ = ["AbstractSurrogateModel", "SurrogateModel", "AbstractSurrogateBaseConfig"]
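With the re-export updated, downstream code can pull the new config class from the same package as the model base class. A small usage sketch, with the import path inferred from the renamed files below:

from codes.surrogates.AbstractSurrogate import (
    AbstractSurrogateBaseConfig,
    AbstractSurrogateModel,
)

config = AbstractSurrogateBaseConfig(optimizer="sgd", momentum=0.9)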
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+
+from torch import nn
+
+
+@dataclass
+class AbstractSurrogateBaseConfig:
+    """
+    Base configuration for the AbstractSurrogate model.
+
+    This class defines shared attributes and methods for surrogate models.
+
+    Attributes:
+        learning_rate (float): Learning rate for the optimizer.
+        regularization_factor (float): Regularization coefficient, applied as weight decay.
+        optimizer (str): Type of optimizer to use. Supported options: adamw, sgd.
+        momentum (float): Momentum factor for the optimizer (used only if optimizer == "sgd").
+        scheduler (str): Type of learning rate scheduler to use.
+            - "schedulefree": Use schedulefree optimizer.
+            - "cosine": Use cosine annealing scheduler.
+            - "poly": Use polynomial decay scheduler.
+        poly_power (float): Power for polynomial decay scheduler (used only if scheduler == "poly").
+        eta_min (float): Multiplier for minimum learning rate for cosine annealing scheduler (used only if scheduler == "cosine").
+        activation (nn.Module): Activation function used in the model.
+    """
+
+    learning_rate: float = 3e-4
+    regularization_factor: float = 0.0
+    optimizer: str = "adamw"  # Options: "adamw", "sgd"
+    momentum: float = 0.0  # Used only if optimizer == "sgd"
+    scheduler: str = "cosine"  # Options: "schedulefree", "cosine", "poly"
+    poly_power: float = 0.9  # Used only if scheduler == "poly"
+    eta_min: float = 1e-1  # Used only if scheduler == "cosine"
+    activation: nn.Module = nn.ReLU()
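Because this is a plain dataclass, concrete surrogates can subclass it to add model-specific hyperparameters while inheriting the shared optimizer and scheduler fields. A hypothetical sketch; MySurrogateConfig and hidden_size are illustrative and not part of this commit:

from dataclasses import dataclass

from torch import nn

from codes.surrogates.AbstractSurrogate import AbstractSurrogateBaseConfig


@dataclass
class MySurrogateConfig(AbstractSurrogateBaseConfig):
    hidden_size: int = 256             # model-specific field (made up)
    activation: nn.Module = nn.Tanh()  # overrides the inherited ReLU default


# Inherited optimizer/scheduler fields can be set per experiment:
config = MySurrogateConfig(learning_rate=1e-3, scheduler="poly", poly_power=0.8)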

codes/surrogates/AbstractSurrogate/surrogates.py renamed to codes/surrogates/AbstractSurrogate/abstract_surrogate.py

Lines changed: 157 additions & 6 deletions
@@ -199,6 +199,7 @@ def fit(
         epochs: int,
         position: int,
         description: str,
+        multi_objective: bool,
     ) -> None:
         """
         Perform the training of the model. Sets the train_loss and test_loss attributes.
@@ -209,16 +210,20 @@ def fit(
             epochs (int): The number of epochs to train the model for.
             position (int): The position of the progress bar.
             description (str): The description of the progress bar.
+            multi_objective (bool): Whether the training is multi-objective.
         """
         pass

-    def predict(self, data_loader: DataLoader) -> tuple[Tensor, Tensor]:
+    def predict(
+        self, data_loader: DataLoader, denormalize: bool = True
+    ) -> tuple[Tensor, Tensor]:
         """
         Evaluate the model on the given dataloader.

         Args:
             data_loader (DataLoader): The DataLoader object containing the data the
                 model is evaluated on.
+            denormalize (bool): Whether to denormalize the predictions and targets.

         Returns:
             tuple[Tensor, Tensor]: The predictions and targets.
@@ -256,8 +261,9 @@ def predict(self, data_loader: DataLoader) -> tuple[Tensor, Tensor]:
         predictions = predictions[:processed_samples, ...]
         targets = targets[:processed_samples, ...]

-        predictions = self.denormalize(predictions)
-        targets = self.denormalize(targets)
+        if denormalize:
+            predictions = self.denormalize(predictions)
+            targets = self.denormalize(targets)

         predictions = predictions.reshape(-1, self.n_timesteps, self.n_quantities)
         targets = targets.reshape(-1, self.n_timesteps, self.n_quantities)
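The new flag lets callers keep predictions and targets in normalised (and, where configured, log10) space, e.g. to compute losses on the same scale the model was trained on. A usage sketch, assuming model is a trained subclass instance and test_loader a DataLoader:

# Default behaviour, unchanged: outputs in physical units.
preds, targets = model.predict(test_loader)

# New: skip denormalisation and stay in the model's training space.
preds_norm, targets_norm = model.predict(test_loader, denormalize=False)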
@@ -419,7 +425,36 @@ def setup_progress_bar(self, epochs: int, position: int, description: str):

         return progress_bar

-    def denormalize(self, data: Tensor) -> Tensor:
+    def denormalize(self, data: Tensor, leave_log: bool = False) -> Tensor:
+        """
+        Denormalize the data.
+
+        Args:
+            data (Tensor): The data to denormalize.
+            leave_log (bool): If True, do not exponentiate the data even if log10_transform is True.
+
+        Returns:
+            Tensor: The denormalized data.
+        """
+        if self.normalisation is not None:
+            if self.normalisation["mode"] == "disabled":
+                ...
+            elif self.normalisation["mode"] == "minmax":
+                dmax = self.normalisation["max"]
+                dmin = self.normalisation["min"]
+                data = data.to("cpu")
+                data = (data + 1) * (dmax - dmin) / 2 + dmin
+            elif self.normalisation["mode"] == "standardize":
+                mean = self.normalisation["mean"]
+                std = self.normalisation["std"]
+                data = data * std + mean
+
+            if self.normalisation["log10_transform"] and not leave_log:
+                data = 10**data
+
+        return data
+
+    def denormalize_old(self, data: Tensor) -> Tensor:
         """
         Denormalize the data.

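In the "minmax" branch, the method inverts a normalisation that maps [dmin, dmax] onto [-1, 1]. A tiny self-contained check of that arithmetic; the forward map shown here is an assumption consistent with the inverse above, and the values are made up:

import torch

dmin, dmax = 0.0, 4.0
x = torch.tensor([0.0, 1.0, 4.0])
normed = 2 * (x - dmin) / (dmax - dmin) - 1          # assumed forward map onto [-1, 1]
restored = (normed + 1) * (dmax - dmin) / 2 + dmin   # inverse used in denormalize()
assert torch.allclose(restored, x)

With leave_log=True, the final 10**data step is skipped, so log-space quantities can be compared without exponentiation.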
@@ -632,13 +667,13 @@ def validate(
         Only runs if (epoch % self.update_epochs) == 0.
         """

-        # 1) If it's not time to check yet, do nothing.
+        # If it's not time to check yet, do nothing.
         if epoch % self.update_epochs != 0:
             return

         index = epoch // self.update_epochs

-        # 2) Switch into inference/eval mode and compute losses
+        # Switch into inference/eval mode and compute losses
         with torch.inference_mode():
             self.eval()
             optimizer.eval() if hasattr(optimizer, "eval") else None
@@ -670,5 +705,121 @@ def validate(
         self.train()
         optimizer.train() if hasattr(optimizer, "train") else None

+    def setup_optimizer_and_scheduler(
+        self,
+        epochs: int,
+    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
+        """
+        Set up optimizer and scheduler based on self.config.scheduler and self.config.optimizer.
+        Supports "adamw", "sgd" optimizers and "schedulefree", "cosine", "poly" schedulers.
+        Patches standard optimizers so that .train() and .eval() exist as no-ops.
+        Patches ScheduleFree optimizers to have a no-op scheduler.step().
+        For ScheduleFree optimizers, use lr warmup for the first 1% of epochs.
+        For Poly scheduler, use a power decay based on self.config.poly_power.
+        For Cosine scheduler, use a minimum learning rate defined by self.config.eta_min.
+
+        Args:
+            epochs (int): The number of epochs the training will run for.
+
+        Returns:
+            tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
+                The optimizer and scheduler instances.
+        Raises:
+            ValueError: If an unknown optimizer or scheduler is specified in the config.
+        """
+        scheduler_name = self.config.scheduler.lower()
+        optimizer_name = self.config.optimizer.lower()
+
+        class DummyScheduler:
+            def step(self, *args, **kwargs):
+                pass
+
+            def state_dict(self):
+                return {}
+
+            def load_state_dict(self, state_dict):
+                pass
+
+        # create optimizer
+        if optimizer_name == "adamw":
+            if scheduler_name == "schedulefree":
+                from schedulefree import AdamWScheduleFree
+
+                optimizer = AdamWScheduleFree(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                    warmup_steps=max(1, epochs // 100),
+                )
+            else:
+                from torch.optim import AdamW
+
+                optimizer = AdamW(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                )
+        elif optimizer_name == "sgd":
+            momentum = self.config.momentum
+            if scheduler_name == "schedulefree":
+                from schedulefree import SGDScheduleFree
+
+                optimizer = SGDScheduleFree(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                    momentum=momentum,
+                    warmup_steps=max(1, epochs // 100),
+                )
+            else:
+                from torch.optim import SGD
+
+                optimizer = SGD(
+                    self.parameters(),
+                    lr=self.config.learning_rate,
+                    weight_decay=self.config.regularization_factor,
+                    momentum=momentum,
+                )
+        else:
+            raise ValueError(f"Unknown optimizer '{self.config.optimizer}'")
+
+        # Patch optimizer to have no-op train() and eval(), if not present
+        if not hasattr(optimizer, "train"):
+
+            def _opt_train():
+                pass
+
+            optimizer.train = _opt_train
+        if not hasattr(optimizer, "eval"):
+
+            def _opt_eval():
+                pass
+
+            optimizer.eval = _opt_eval
+
+        # create scheduler
+        if scheduler_name == "schedulefree":
+            scheduler = DummyScheduler()
+        elif scheduler_name == "cosine":
+            from torch.optim.lr_scheduler import CosineAnnealingLR
+
+            eta_min = self.config.eta_min
+            scheduler = CosineAnnealingLR(
+                optimizer,
+                T_max=epochs,
+                eta_min=eta_min,
+            )
+        elif scheduler_name == "poly":
+            from torch.optim.lr_scheduler import LambdaLR
+
+            power = self.config.poly_power
+            scheduler = LambdaLR(
+                optimizer, lr_lambda=lambda epoch: (1 - epoch / float(epochs)) ** power
+            )
+        else:
+            raise ValueError(f"Unknown scheduler '{self.config.scheduler}'")
+
+        return optimizer, scheduler
+

 SurrogateModel = TypeVar("SurrogateModel", bound=AbstractSurrogateModel)
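Taken together, a fit() implementation in a subclass would use the helper roughly as follows. This is a sketch, not the repository's actual training code: model is assumed to be a concrete surrogate instance with a populated config, and train_loader and criterion are assumed to exist.

optimizer, scheduler = model.setup_optimizer_and_scheduler(epochs)
optimizer.train()  # meaningful for ScheduleFree optimizers; a patched no-op otherwise
for epoch in range(epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
    scheduler.step()  # per-epoch step; DummyScheduler makes this a no-op

Both schedulers are configured for per-epoch stepping (CosineAnnealingLR with T_max=epochs, and the poly LambdaLR factor (1 - epoch/epochs)**power decaying toward zero over the run), so scheduler.step() belongs outside the batch loop.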
