@@ -215,23 +215,23 @@ def fit(
215215 pass
216216
217217 def predict (
218- self , data_loader : DataLoader , denormalize : bool = True
218+ self , data_loader : DataLoader , leave_log : bool = False
219219 ) -> tuple [Tensor , Tensor ]:
220220 """
221221 Evaluate the model on the given dataloader.
222222
223223 Args:
224224 data_loader (DataLoader): The DataLoader object containing the data the
225225 model is evaluated on.
226- denormalize (bool): Whether to denormalize the predictions and targets .
226+ leave_log (bool): If True, do not exponentiate the data even if log10_transform is True .
227227
228228 Returns:
229229 tuple[Tensor, Tensor]: The predictions and targets.
230230 """
231231 # infer output size
232232 with torch .inference_mode ():
233233 dummy_inputs = next (iter (data_loader ))
234- dummy_outputs , _ = self . forward (dummy_inputs )
234+ dummy_outputs , _ = self (dummy_inputs )
235235 batch_size , out_shape = (
236236 dummy_outputs .shape [0 ],
237237 dummy_outputs .shape [- (dummy_outputs .ndim - 1 ) :],
@@ -247,7 +247,11 @@ def predict(
247247
248248 with torch .inference_mode ():
249249 for inputs in data_loader :
250- preds , targs = self .forward (inputs )
250+ inputs = [
251+ x .to (self .device , non_blocking = True ) if isinstance (x , Tensor ) else x
252+ for x in inputs
253+ ]
254+ preds , targs = self (inputs )
251255 current_batch_size = preds .shape [0 ] # get actual batch size
252256 predictions [
253257 processed_samples : processed_samples + current_batch_size , ...
@@ -261,9 +265,8 @@ def predict(
261265 predictions = predictions [:processed_samples , ...]
262266 targets = targets [:processed_samples , ...]
263267
264- if denormalize :
265- predictions = self .denormalize (predictions )
266- targets = self .denormalize (targets )
268+ predictions = self .denormalize (predictions , leave_log = leave_log )
269+ targets = self .denormalize (targets , leave_log = leave_log )
267270
268271 predictions = predictions .reshape (- 1 , self .n_timesteps , self .n_quantities )
269272 targets = targets .reshape (- 1 , self .n_timesteps , self .n_quantities )
@@ -499,7 +502,7 @@ def time_pruning(self, current_epoch: int, total_epochs: int) -> None:
499502 optuna.TrialPruned: If the projected runtime exceeds the threshold.
500503 """
501504 # Define warmup period: at least 10 epochs, or 2% of total epochs.
502- warmup_epochs = max (50 , int (total_epochs * 0.02 ))
505+ warmup_epochs = max (10 , int (total_epochs * 0.02 ))
503506 if current_epoch < warmup_epochs :
504507 # Do not attempt to prune before the warmup period is complete.
505508 # print(
@@ -645,7 +648,6 @@ def validate(
645648 epoch : int ,
646649 train_loader : DataLoader ,
647650 test_loader : DataLoader ,
648- criterion : nn .Module ,
649651 optimizer : torch .optim .Optimizer ,
650652 progress_bar : tqdm ,
651653 total_epochs : int ,
@@ -665,6 +667,7 @@ def validate(
665667 - self.checkpoint(test_loss, epoch)
666668
667669 Only runs if (epoch % self.update_epochs) == 0.
670+ Main reporting metric is MAE in log10-space (i.e., Δdex). Additionally, MAE in linear space is computed.
668671 """
669672
670673 # If it's not time to check yet, do nothing.
@@ -679,10 +682,11 @@ def validate(
679682 optimizer .eval () if hasattr (optimizer , "eval" ) else None
680683
681684 # Compute losses
682- preds , targets = self .predict (train_loader )
683- self .train_loss [index ] = criterion (preds , targets ).item ()
685+ preds , targets = self .predict (train_loader , leave_log = True )
686+ self .train_loss [index ] = self .L1 (preds , targets ).item ()
687+ preds , targets = self .predict (test_loader , leave_log = True )
688+ self .test_loss [index ] = self .L1 (preds , targets ).item ()
684689 preds , targets = self .predict (test_loader )
685- self .test_loss [index ] = criterion (preds , targets ).item ()
686690 self .MAE [index ] = self .L1 (preds , targets ).item ()
687691
688692 progress_bar .set_postfix (
@@ -699,6 +703,12 @@ def validate(
699703 self .optuna_trial .report (self .test_loss [index ], step = epoch )
700704 if self .optuna_trial .should_prune ():
701705 raise optuna .TrialPruned ()
706+ elif np .isinf (self .test_loss [index ]) or np .isnan (
707+ self .test_loss [index ]
708+ ):
709+ raise optuna .TrialPruned (
710+ "Test loss is NaN or Inf, pruning trial."
711+ )
702712
703713 self .checkpoint (self .test_loss [index ], epoch )
704714
0 commit comments