2020 plot_error_correlation_heatmap ,
2121 plot_error_distribution_comparative ,
2222 plot_error_distribution_per_quantity ,
23+ plot_error_percentiles_over_time ,
24+ plot_example_iterative_predictions ,
2325 plot_example_mode_predictions ,
2426 plot_example_predictions_with_uncertainty ,
2527 plot_generalization_error_comparison ,
2628 plot_loss_comparison ,
2729 plot_loss_comparison_equal ,
2830 plot_loss_comparison_train_duration ,
2931 plot_relative_errors ,
30- plot_relative_errors_over_time ,
3132 plot_surr_losses ,
3233 plot_uncertainty_confidence ,
3334 plot_uncertainty_over_time_comparison ,
@@ -137,6 +138,13 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
137138 model , surr_name , timesteps , val_loader , conf , labels
138139 )
139140
141+ if conf ["iterative" ]:
142+ # Iterative training benchmark
143+ print ("Running iterative training benchmark..." )
144+ metrics ["iterative" ] = evaluate_iterative_predictions (
145+ model , surr_name , timesteps , val_loader , conf , labels
146+ )
147+
140148 # Gradients benchmark
141149 if conf ["gradients" ]:
142150 print ("Running gradients benchmark..." )
@@ -207,20 +215,22 @@ def evaluate_accuracy(
207215 labels : list | None = None ,
208216) -> dict [str , Any ]:
209217 """
210- Evaluate the accuracy of the surrogate model.
211- quantitiesquantities
212- Args:
213- model: Instance of the surrogate model class.
214- surr_name (str): The name of the surrogate model.
215- timesteps (np.ndarray): The timesteps array.
216- test_loader (DataLoader): The DataLoader object containing the test data.
217- conf (dict): The configuration dictionary.
218- labels (list, optional): The labels for the quantities.
219-
220- Returns:
221- dict: A dictionary containing accuracy metrics.
218+ Evaluate the accuracy of the surrogate model.
219+
220+ Args:
221+ model: Instance of the surrogate model class.
222+ surr_name (str): The name of the surrogate model.
223+ timesteps (np.ndarray): The timesteps array.
224+ test_loader (DataLoader): The DataLoader object containing the test data.
225+ conf (dict): The configuration dictionary.
226+ labels (list, optional): The labels for the quantities.
227+ percentile (int, optional): The percentile for error metrics.
228+
229+ Returns:
230+ dict: A dictionary containing accuracy metrics.
222231 """
223232 training_id = conf ["training_id" ]
233+ percentile = conf .get ("error_percentile" , 99 )
224234
225235 # Load the model
226236 model .load (training_id , surr_name , model_identifier = f"{ surr_name .lower ()} _main" )
@@ -229,27 +239,56 @@ def evaluate_accuracy(
229239 model_index = conf ["surrogates" ].index (surr_name )
230240 n_epochs = conf ["epochs" ][model_index ]
231241
232- # Use the model's predict method
233- criterion = torch .nn .MSELoss ()
234- preds , targets = model .predict (data_loader = test_loader )
235- mean_squared_error = criterion (preds , targets ).item () # / torch.numel(preds)
242+ # Obtain log-space predictions and targets
243+ preds , targets = model .predict (data_loader = test_loader , leave_log = True )
244+ preds , targets = preds .detach ().cpu ().numpy (), targets .detach ().cpu ().numpy ()
245+
246+ # Compute log-space error metrics
247+ absolute_errors_log = np .abs (preds - targets )
248+ root_mean_squared_error_log = np .sqrt (np .mean (absolute_errors_log ** 2 ))
249+ median_absolute_error_log = np .median (absolute_errors_log )
250+ mean_absolute_error_log = np .mean (absolute_errors_log )
251+ percentile_absolute_error_log = np .percentile (absolute_errors_log , percentile )
252+
253+ # Obtain real-space predictions and targets
254+ preds , targets = model .predict (data_loader = test_loader , leave_log = False )
236255 preds , targets = preds .detach ().cpu ().numpy (), targets .detach ().cpu ().numpy ()
237256
238- # Calculate relative errors
257+ # Compute real-space error metrics
239258 absolute_errors = np .abs (preds - targets )
240- mean_absolute_error = np .mean (absolute_errors )
259+ root_mean_squared_error_real = np .sqrt (np .mean (absolute_errors ** 2 ))
260+ median_absolute_error_real = np .median (absolute_errors )
261+ mean_absolute_error_real = np .mean (absolute_errors )
262+ percentile_absolute_error_real = np .percentile (absolute_errors , percentile )
263+
264+ # Additional real-space errors: Relative error
241265 relative_error_threshold = float (conf .get ("relative_error_threshold" , 0.0 ))
242266 relative_errors = np .abs (
243267 absolute_errors / np .maximum (np .abs (targets ), relative_error_threshold )
244268 )
269+ median_relative_error = np .median (relative_errors )
270+ mean_relative_error = np .mean (relative_errors )
271+ percentile_relative_error = np .percentile (relative_errors , percentile )
245272
246- # Plot relative errors over time
247- plot_relative_errors_over_time (
273+ plot_error_percentiles_over_time (
248274 surr_name ,
249275 conf ,
250276 relative_errors ,
251277 timesteps ,
252278 title = f"Relative Errors over Time for { surr_name } " ,
279+ mode = "relative" ,
280+ save = True ,
281+ show_title = TITLE ,
282+ )
283+
284+ plot_error_percentiles_over_time (
285+ surr_name ,
286+ conf ,
287+ absolute_errors_log ,
288+ timesteps ,
289+ title = r"$\Delta dex$ (Absolute Log-Space) Errors over Time for "
290+ + f"{ surr_name } " ,
291+ mode = "deltadex" ,
253292 save = True ,
254293 show_title = TITLE ,
255294 )
@@ -266,21 +305,161 @@ def evaluate_accuracy(
266305
267306 # Store metrics
268307 accuracy_metrics = {
269- "mean_squared_error" : mean_squared_error ,
270- "mean_absolute_error" : mean_absolute_error ,
271- "mean_relative_error" : np .mean (relative_errors ),
272- "median_relative_error" : np .median (relative_errors ),
273- "max_relative_error" : np .max (relative_errors ),
274- "min_relative_error" : np .min (relative_errors ),
275- "absolute_errors" : absolute_errors ,
276- "relative_errors" : relative_errors ,
308+ "root_mean_squared_error_log" : root_mean_squared_error_log ,
309+ "median_absolute_error_log" : median_absolute_error_log ,
310+ "mean_absolute_error_log" : mean_absolute_error_log ,
311+ "percentile_absolute_error_log" : percentile_absolute_error_log ,
312+ "root_mean_squared_error_real" : root_mean_squared_error_real ,
313+ "median_absolute_error_real" : median_absolute_error_real ,
314+ "mean_absolute_error_real" : mean_absolute_error_real ,
315+ "percentile_absolute_error_real" : percentile_absolute_error_real ,
316+ "median_relative_error" : median_relative_error ,
317+ "mean_relative_error" : mean_relative_error ,
318+ "percentile_relative_error" : percentile_relative_error ,
319+ "error_percentile" : percentile ,
277320 "main_model_training_time" : train_time ,
278321 "main_model_epochs" : n_epochs ,
322+ "absolute_errors" : absolute_errors ,
323+ "relative_errors" : relative_errors ,
279324 }
280325
281326 return accuracy_metrics
282327
283328
def evaluate_iterative_predictions(
    model,
    surr_name: str,
    timesteps: np.ndarray,
    val_loader: DataLoader,
    conf: dict,
    labels: list | None = None,
) -> dict[str, Any]:
    """
    Evaluate the iterative (auto-regressive) predictions of the surrogate model.

    The full trajectory is rebuilt in chunks of ``iter_interval`` timesteps:
    each chunk is predicted from an initial state that is the last prediction
    of the previous chunk, i.e. the model's own output is re-fed as input.
    Error metrics mirror those of ``evaluate_accuracy`` but are computed over
    this piecewise trajectory.

    Args:
        model: Instance of the surrogate model class.
        surr_name (str): The name of the surrogate model.
        timesteps (np.ndarray): The timesteps array (used for plotting).
        val_loader (DataLoader): The DataLoader object containing the validation data.
        conf (dict): The configuration dictionary.
        labels (list, optional): The labels for the quantities.

    Returns:
        dict: A dictionary containing iterative-prediction error metrics.
    """
    # Load the trained main model.
    training_id = conf["training_id"]
    model.load(training_id, surr_name, model_identifier=f"{surr_name.lower()}_main")

    # Get the full ground truth (targets); the one-shot predictions are kept
    # only for the comparison plot below.
    full_preds, targets = model.predict(
        data_loader=val_loader, leave_log=True, leave_norm=True
    )
    targets = targets.detach().cpu().numpy()
    n_samples, n_timesteps, n_quantities = targets.shape

    original_n_timesteps = model.n_timesteps

    # Timesteps predicted per chunk. Config may override via
    # "iterative_interval"; the default preserves the previous value of 10.
    iter_interval = int(conf.get("iterative_interval", 10))

    # Batch size selection mirrors run_benchmark.
    surr_idx = conf["surrogates"].index(surr_name)
    if isinstance(conf["batch_size"], list):
        batch_size = conf["batch_size"][surr_idx]
    else:
        batch_size = conf["batch_size"]

    # Container for the piecewise predictions.
    iterative_preds = np.zeros_like(targets)

    # Number of chunks (ceiling division).
    n_chunks = (n_timesteps + iter_interval - 1) // iter_interval

    # Normalized timesteps array spanning the full trajectory.
    timesteps_full = np.linspace(0, 1, original_n_timesteps)

    # model.n_timesteps is mutated per chunk; restore it in `finally` so an
    # exception mid-loop cannot leave the model in a corrupted state.
    try:
        for i in range(n_chunks):
            start = i * iter_interval
            end = min(start + iter_interval, n_timesteps)
            # One extra step: index 0 carries the fed-in initial state.
            model.n_timesteps = end - start + 1

            # Initial state: ground truth for the first chunk, then the last
            # prediction of the previous chunk.
            if i == 0:
                init_state = targets[:, 0, :]
            else:
                init_state = iterative_preds[:, start - 1, :]

            # Build a dummy dataset: only the first slice matters for prepare_data.
            ds = np.zeros((n_samples, model.n_timesteps, n_quantities))
            ds[:, 0, :] = init_state

            # Only the "train" loader is needed for prediction.
            dt = timesteps_full[: model.n_timesteps]
            train_loader, _, _ = model.prepare_data(
                dataset_train=ds,
                dataset_test=None,
                dataset_val=None,
                timesteps=dt,
                batch_size=batch_size,
                shuffle=False,
                dataset_train_params=None,
                dataset_test_params=None,
                dataset_val_params=None,
                dummy_timesteps=False,
            )

            # Predict this chunk and insert it into the global array, skipping
            # index 0 (the fed-in initial state).
            preds_chunk, _ = model.predict(
                data_loader=train_loader, leave_log=True, leave_norm=True
            )
            iterative_preds[:, start:end, :] = (
                preds_chunk[:, 1 : model.n_timesteps, :].detach().cpu().numpy()
            )
    finally:
        # Restore the original number of timesteps.
        model.n_timesteps = original_n_timesteps

    # Denormalize into log space and real space.
    iterative_preds_log = model.denormalize(iterative_preds, leave_log=True)
    targets_log = model.denormalize(targets, leave_log=True)
    iterative_preds = model.denormalize(iterative_preds)
    full_preds = model.denormalize(full_preds.detach().cpu().numpy())
    targets = model.denormalize(targets)

    # Real-space error metrics.
    errors = iterative_preds - targets
    abs_errors = np.abs(errors)
    mse = float(np.mean(errors**2))
    mae = float(np.mean(abs_errors))

    # Log-space error metrics.
    abs_errors_log = np.abs(iterative_preds_log - targets_log)
    # Fix: take the square root so the value actually is an RMSE, matching
    # the "root_mean_squared_error_log" metric in evaluate_accuracy.
    rmse_log = float(np.sqrt(np.mean(abs_errors_log**2)))
    mae_log = float(np.mean(abs_errors_log))
    percentile = conf.get("error_percentile", 99)
    percentile_abs_error_log = float(np.percentile(abs_errors_log, percentile))

    # Pick the sample whose mean absolute error is closest to the median —
    # a "typical" example for the plot.
    per_sample_errors = np.mean(np.abs(iterative_preds - targets), axis=(1, 2))
    example_idx = int(
        np.argsort(np.abs(per_sample_errors - np.median(per_sample_errors)))[0]
    )

    plot_example_iterative_predictions(
        surr_name,
        conf,
        iterative_preds,
        full_preds,
        targets,
        timesteps,
        iter_interval=iter_interval,
        example_idx=example_idx,
        labels=labels,
        save=True,
        show_title=TITLE,
    )

    return {
        "root_mean_squared_error_log": rmse_log,
        "mean_absolute_error_log": mae_log,
        "percentile_absolute_error_log": percentile_abs_error_log,
        "error_percentile": percentile,
        "mean_squared_error": mse,
        "mean_absolute_error": mae,
        "absolute_errors": abs_errors,
    }
461+
462+
284463def evaluate_dynamic_accuracy (
285464 model ,
286465 surr_name : str ,
0 commit comments