
Commit 66622ed

Merge pull request #44 from robin-janssen/improve-losses-and-objectives
Improve losses and objectives
2 parents 9af1345 + fb91009 commit 66622ed

19 files changed: 1297 additions & 1231 deletions

codes/surrogates/AbstractSurrogate/abstract_config.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -22,6 +22,8 @@ class AbstractSurrogateBaseConfig:
         poly_power (float): Power for polynomial decay scheduler (used only if scheduler == "poly").
         eta_min (float): Multiplier for minimum learning rate for cosine annealing scheduler (used only if scheduler == "cosine").
         activation (nn.Module): Activation function used in the model.
+        loss_function (nn.Module): Loss function used for training.
+        beta (float): Beta parameter for the loss function (used only if loss_function == nn.SmoothL1Loss()).
     """

     learning_rate: float = 3e-4
@@ -32,3 +34,17 @@ class AbstractSurrogateBaseConfig:
     poly_power: float = 0.9  # Used only if scheduler == "poly"
     eta_min: float = 1e-1  # Used only if scheduler == "cosine"
     activation: nn.Module = nn.ReLU()
+    loss_function: nn.Module = nn.MSELoss()  # Options: nn.MSELoss(), nn.SmoothL1Loss()
+    beta: float = 0.0  # Used only if loss_function == nn.SmoothL1Loss()
+
+    @property
+    def loss(self) -> nn.Module:
+        """
+        Returns the loss function to be used for training.
+
+        If the loss function is nn.SmoothL1Loss, returns a new instance with the
+        specified beta. Otherwise, returns the loss function as is.
+        """
+        if isinstance(self.loss_function, nn.SmoothL1Loss):
+            return nn.SmoothL1Loss(beta=self.beta)
+        return self.loss_function
```
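The `loss` property rebuilds `nn.SmoothL1Loss` with the configured `beta`, while any other loss is returned unchanged. A minimal, self-contained sketch of the pattern (the `LossConfig` class below is a hypothetical stand-in, not the repository's code):

```python
from dataclasses import dataclass, field

import torch
import torch.nn as nn


@dataclass
class LossConfig:
    """Hypothetical stand-in for the loss-related part of AbstractSurrogateBaseConfig."""

    loss_function: nn.Module = field(default_factory=nn.MSELoss)
    beta: float = 0.0  # only honoured for nn.SmoothL1Loss

    @property
    def loss(self) -> nn.Module:
        # Rebuild SmoothL1Loss so the configured beta is actually applied.
        if isinstance(self.loss_function, nn.SmoothL1Loss):
            return nn.SmoothL1Loss(beta=self.beta)
        return self.loss_function


config = LossConfig(loss_function=nn.SmoothL1Loss(), beta=0.1)
criterion = config.loss  # nn.SmoothL1Loss(beta=0.1)
value = criterion(torch.randn(8, 4), torch.randn(8, 4))  # scalar loss tensor
```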

codes/surrogates/AbstractSurrogate/abstract_surrogate.py

Lines changed: 22 additions & 12 deletions
```diff
@@ -215,23 +215,23 @@ def fit(
         pass

     def predict(
-        self, data_loader: DataLoader, denormalize: bool = True
+        self, data_loader: DataLoader, leave_log: bool = False
     ) -> tuple[Tensor, Tensor]:
         """
         Evaluate the model on the given dataloader.

         Args:
             data_loader (DataLoader): The DataLoader object containing the data the
                 model is evaluated on.
-            denormalize (bool): Whether to denormalize the predictions and targets.
+            leave_log (bool): If True, do not exponentiate the data even if log10_transform is True.

         Returns:
             tuple[Tensor, Tensor]: The predictions and targets.
         """
         # infer output size
         with torch.inference_mode():
             dummy_inputs = next(iter(data_loader))
-            dummy_outputs, _ = self.forward(dummy_inputs)
+            dummy_outputs, _ = self(dummy_inputs)
             batch_size, out_shape = (
                 dummy_outputs.shape[0],
                 dummy_outputs.shape[-(dummy_outputs.ndim - 1) :],
```
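Note the switch from `self.forward(...)` to `self(...)`: calling the module routes through `nn.Module.__call__`, so registered forward hooks fire. A small illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model.register_forward_hook(lambda mod, inp, out: print("hook fired"))

x = torch.randn(1, 4)
model(x)          # goes through nn.Module.__call__: prints "hook fired"
model.forward(x)  # bypasses __call__: the hook does not fire
```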
```diff
@@ -247,7 +247,11 @@ def predict(

         with torch.inference_mode():
             for inputs in data_loader:
-                preds, targs = self.forward(inputs)
+                inputs = [
+                    x.to(self.device, non_blocking=True) if isinstance(x, Tensor) else x
+                    for x in inputs
+                ]
+                preds, targs = self(inputs)
                 current_batch_size = preds.shape[0]  # get actual batch size
                 predictions[
                     processed_samples : processed_samples + current_batch_size, ...
```
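The loop now moves each tensor in the batch to the model's device before the forward pass, leaving non-tensor elements alone. The pattern in isolation (function name and batch layout below are assumptions for illustration):

```python
import torch
from torch import Tensor


def move_batch(inputs: list, device: torch.device) -> list:
    # Move only the tensors; leave non-tensor elements (e.g. metadata) untouched.
    return [
        x.to(device, non_blocking=True) if isinstance(x, Tensor) else x
        for x in inputs
    ]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = [torch.randn(32, 10), torch.randn(32, 1), "sample_ids"]
batch = move_batch(batch, device)
```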
```diff
@@ -261,9 +265,8 @@ def predict(
         predictions = predictions[:processed_samples, ...]
         targets = targets[:processed_samples, ...]

-        if denormalize:
-            predictions = self.denormalize(predictions)
-            targets = self.denormalize(targets)
+        predictions = self.denormalize(predictions, leave_log=leave_log)
+        targets = self.denormalize(targets, leave_log=leave_log)

         predictions = predictions.reshape(-1, self.n_timesteps, self.n_quantities)
         targets = targets.reshape(-1, self.n_timesteps, self.n_quantities)
```
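`predict` now always denormalizes and forwards `leave_log` to `denormalize`. The actual `denormalize` implementation is not part of this diff; a sketch of how such a flag plausibly interacts with a log10 transform (all names besides `leave_log` are assumptions):

```python
import torch
from torch import Tensor


def denormalize(
    x: Tensor, mean: Tensor, std: Tensor, log10_transform: bool, leave_log: bool = False
) -> Tensor:
    x = x * std + mean  # undo standardization
    if log10_transform and not leave_log:
        x = torch.pow(10.0, x)  # back to linear space
    return x  # with leave_log=True, values stay in log10 space
```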
```diff
@@ -499,7 +502,7 @@ def time_pruning(self, current_epoch: int, total_epochs: int) -> None:
             optuna.TrialPruned: If the projected runtime exceeds the threshold.
         """
         # Define warmup period based on 2% of total epochs (with a lower floor).
-        warmup_epochs = max(50, int(total_epochs * 0.02))
+        warmup_epochs = max(10, int(total_epochs * 0.02))
         if current_epoch < warmup_epochs:
             # Do not attempt to prune before the warmup period is complete.
             # print(
```
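The warmup floor drops from 50 to 10 epochs while the 2% scaling is kept, so short runs can now be time-pruned much earlier. A quick check of the arithmetic:

```python
def warmup_epochs(total_epochs: int) -> int:
    # Pruning is deferred until 2% of total epochs have passed, but never fewer than 10.
    return max(10, int(total_epochs * 0.02))


assert warmup_epochs(100) == 10    # previously max(50, 2) == 50
assert warmup_epochs(1000) == 20   # previously max(50, 20) == 50
assert warmup_epochs(5000) == 100  # unchanged: the 2% term dominates the floor
```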
```diff
@@ -645,7 +648,6 @@ def validate(
         epoch: int,
         train_loader: DataLoader,
         test_loader: DataLoader,
-        criterion: nn.Module,
         optimizer: torch.optim.Optimizer,
         progress_bar: tqdm,
         total_epochs: int,
```
```diff
@@ -665,6 +667,7 @@ def validate(
         - self.checkpoint(test_loss, epoch)

         Only runs if (epoch % self.update_epochs) == 0.
+        Main reporting metric is MAE in log10-space (i.e., Δdex). Additionally, MAE in linear space is computed.
         """

         # If it's not time to check yet, do nothing.
```
```diff
@@ -679,10 +682,11 @@ def validate(
         optimizer.eval() if hasattr(optimizer, "eval") else None

         # Compute losses
-        preds, targets = self.predict(train_loader)
-        self.train_loss[index] = criterion(preds, targets).item()
+        preds, targets = self.predict(train_loader, leave_log=True)
+        self.train_loss[index] = self.L1(preds, targets).item()
+        preds, targets = self.predict(test_loader, leave_log=True)
+        self.test_loss[index] = self.L1(preds, targets).item()
         preds, targets = self.predict(test_loader)
-        self.test_loss[index] = criterion(preds, targets).item()
         self.MAE[index] = self.L1(preds, targets).item()

         progress_bar.set_postfix(
```
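With `leave_log=True`, the train and test losses are now MAEs over log10-space values, i.e. Δdex, while the separate `MAE` entry stays in linear space. To illustrate the metric (the helper below is not from the repository):

```python
import torch


def delta_dex(preds_log10: torch.Tensor, targets_log10: torch.Tensor) -> torch.Tensor:
    # MAE in log10 space: a value of 0.1 means predictions are off by
    # a factor of 10**0.1 ≈ 1.26 on average.
    return (preds_log10 - targets_log10).abs().mean()


targets = torch.tensor([1.0, 2.0, 3.0])  # log10 of true quantities
preds = torch.tensor([1.1, 1.9, 3.0])    # log10 of predictions
print(delta_dex(preds, targets))         # tensor(0.0667), i.e. ≈ 0.067 dex
```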
```diff
@@ -699,6 +703,12 @@ def validate(
             self.optuna_trial.report(self.test_loss[index], step=epoch)
             if self.optuna_trial.should_prune():
                 raise optuna.TrialPruned()
+            elif np.isinf(self.test_loss[index]) or np.isnan(
+                self.test_loss[index]
+            ):
+                raise optuna.TrialPruned(
+                    "Test loss is NaN or Inf, pruning trial."
+                )

         self.checkpoint(self.test_loss[index], epoch)
```
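The new guard prunes a trial outright once the test loss has diverged, rather than waiting for Optuna's pruner to catch it. A minimal sketch of the same logic outside the class (using `math` instead of `numpy`):

```python
import math

import optuna


def report_or_prune(trial: optuna.Trial, test_loss: float, epoch: int) -> None:
    # Mirror of the guard added in this commit: prune on the pruner's schedule,
    # and also prune immediately if the loss is no longer finite.
    trial.report(test_loss, step=epoch)
    if trial.should_prune():
        raise optuna.TrialPruned()
    elif math.isinf(test_loss) or math.isnan(test_loss):
        raise optuna.TrialPruned("Test loss is NaN or Inf, pruning trial.")
```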
0 commit comments