
Commit c8a3836

Merge pull request #47 from robin-janssen/improve-latentneuralode
Improve latentneuralode
2 parents 25e9494 + 6ade2be · commit c8a3836

25 files changed

Lines changed: 1319 additions & 911 deletions

codes/benchmark/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -35,6 +35,8 @@
     plot_error_correlation_heatmap,
     plot_error_distribution_comparative,
     plot_error_distribution_per_quantity,
+    plot_error_percentiles_over_time,
+    plot_example_iterative_predictions,
     plot_example_mode_predictions,
     plot_example_predictions_with_uncertainty,
     plot_generalization_error_comparison,
@@ -44,7 +46,6 @@
     plot_losses,
     plot_MAE_comparison,
     plot_relative_errors,
-    plot_relative_errors_over_time,
     plot_surr_losses,
     plot_uncertainty_confidence,
     plot_uncertainty_over_time_comparison,
@@ -69,11 +70,11 @@
     get_surrogate,
     load_model,
     make_comparison_csv,
+    measure_inference_time,
     measure_memory_footprint,
     read_yaml_config,
     save_table_csv,
     write_metrics_to_yaml,
-    measure_inference_time,
 )

 __all__ = [
@@ -100,12 +101,12 @@
     "tabular_comparison",
     "save_plot",
     "save_plot_counter",
-    "plot_relative_errors_over_time",
     "plot_dynamic_correlation",
     "plot_generalization_errors",
     "plot_average_errors_over_time",
     "plot_example_predictions_with_uncertainty",
     "plot_example_mode_predictions",
+    "plot_example_iterative_predictions",
     "plot_average_uncertainty_over_time",
     "plot_uncertainty_vs_errors",
     "plot_uncertainty_confidence",
@@ -122,6 +123,7 @@
     "plot_error_correlation_heatmap",
     "plot_dynamic_correlation_heatmap",
     "plot_error_distribution_comparative",
+    "plot_error_percentiles_over_time",
     "plot_comparative_error_correlation_heatmaps",
     "plot_comparative_dynamic_correlation_heatmaps",
     "get_custom_palette",

codes/benchmark/bench_fcts.py

Lines changed: 208 additions & 29 deletions
@@ -20,14 +20,15 @@
     plot_error_correlation_heatmap,
     plot_error_distribution_comparative,
     plot_error_distribution_per_quantity,
+    plot_error_percentiles_over_time,
+    plot_example_iterative_predictions,
     plot_example_mode_predictions,
     plot_example_predictions_with_uncertainty,
     plot_generalization_error_comparison,
     plot_loss_comparison,
     plot_loss_comparison_equal,
     plot_loss_comparison_train_duration,
     plot_relative_errors,
-    plot_relative_errors_over_time,
     plot_surr_losses,
     plot_uncertainty_confidence,
     plot_uncertainty_over_time_comparison,
@@ -137,6 +138,13 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
         model, surr_name, timesteps, val_loader, conf, labels
     )

+    if conf["iterative"]:
+        # Iterative training benchmark
+        print("Running iterative training benchmark...")
+        metrics["iterative"] = evaluate_iterative_predictions(
+            model, surr_name, timesteps, val_loader, conf, labels
+        )
+
     # Gradients benchmark
     if conf["gradients"]:
         print("Running gradients benchmark...")
@@ -207,20 +215,22 @@ def evaluate_accuracy(
     labels: list | None = None,
 ) -> dict[str, Any]:
     """
-    Evaluate the accuracy of the surrogate model.
-    quantitiesquantities
-    Args:
-        model: Instance of the surrogate model class.
-        surr_name (str): The name of the surrogate model.
-        timesteps (np.ndarray): The timesteps array.
-        test_loader (DataLoader): The DataLoader object containing the test data.
-        conf (dict): The configuration dictionary.
-        labels (list, optional): The labels for the quantities.
-
-    Returns:
-        dict: A dictionary containing accuracy metrics.
+    Evaluate the accuracy of the surrogate model.
+
+    Args:
+        model: Instance of the surrogate model class.
+        surr_name (str): The name of the surrogate model.
+        timesteps (np.ndarray): The timesteps array.
+        test_loader (DataLoader): The DataLoader object containing the test data.
+        conf (dict): The configuration dictionary.
+        labels (list, optional): The labels for the quantities.
+        percentile (int, optional): The percentile for error metrics.
+
+    Returns:
+        dict: A dictionary containing accuracy metrics.
     """
     training_id = conf["training_id"]
+    percentile = conf.get("error_percentile", 99)

     # Load the model
     model.load(training_id, surr_name, model_identifier=f"{surr_name.lower()}_main")
@@ -229,27 +239,56 @@ def evaluate_accuracy(
     model_index = conf["surrogates"].index(surr_name)
     n_epochs = conf["epochs"][model_index]

-    # Use the model's predict method
-    criterion = torch.nn.MSELoss()
-    preds, targets = model.predict(data_loader=test_loader)
-    mean_squared_error = criterion(preds, targets).item()  # / torch.numel(preds)
+    # Obtain log-space predictions and targets
+    preds, targets = model.predict(data_loader=test_loader, leave_log=True)
+    preds, targets = preds.detach().cpu().numpy(), targets.detach().cpu().numpy()
+
+    # Compute log-space error metrics
+    absolute_errors_log = np.abs(preds - targets)
+    root_mean_squared_error_log = np.sqrt(np.mean(absolute_errors_log**2))
+    median_absolute_error_log = np.median(absolute_errors_log)
+    mean_absolute_error_log = np.mean(absolute_errors_log)
+    percentile_absolute_error_log = np.percentile(absolute_errors_log, percentile)
+
+    # Obtain real-space predictions and targets
+    preds, targets = model.predict(data_loader=test_loader, leave_log=False)
     preds, targets = preds.detach().cpu().numpy(), targets.detach().cpu().numpy()

-    # Calculate relative errors
+    # Compute real-space error metrics
     absolute_errors = np.abs(preds - targets)
-    mean_absolute_error = np.mean(absolute_errors)
+    root_mean_squared_error_real = np.sqrt(np.mean(absolute_errors**2))
+    median_absolute_error_real = np.median(absolute_errors)
+    mean_absolute_error_real = np.mean(absolute_errors)
+    percentile_absolute_error_real = np.percentile(absolute_errors, percentile)
+
+    # Additional real-space errors: Relative error
     relative_error_threshold = float(conf.get("relative_error_threshold", 0.0))
     relative_errors = np.abs(
         absolute_errors / np.maximum(np.abs(targets), relative_error_threshold)
     )
+    median_relative_error = np.median(relative_errors)
+    mean_relative_error = np.mean(relative_errors)
+    percentile_relative_error = np.percentile(relative_errors, percentile)

-    # Plot relative errors over time
-    plot_relative_errors_over_time(
+    plot_error_percentiles_over_time(
         surr_name,
         conf,
         relative_errors,
         timesteps,
         title=f"Relative Errors over Time for {surr_name}",
+        mode="relative",
+        save=True,
+        show_title=TITLE,
+    )
+
+    plot_error_percentiles_over_time(
+        surr_name,
+        conf,
+        absolute_errors_log,
+        timesteps,
+        title=r"$\Delta dex$ (Absolute Log-Space) Errors over Time for "
+        + f"{surr_name}",
+        mode="deltadex",
         save=True,
         show_title=TITLE,
     )
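For readers unfamiliar with the Δdex convention used in the second plot: since predictions and targets are compared in log space, the absolute log-space error is an error in dex, i.e. orders of magnitude. A self-contained numpy sketch of the metrics computed above — the arrays, shapes, and threshold value are hypothetical stand-ins for the output of `model.predict` with `leave_log=True`/`False`:

    import numpy as np

    # Hypothetical stand-ins of shape (n_samples, n_timesteps, n_quantities),
    # mimicking model.predict(..., leave_log=True) in this commit.
    rng = np.random.default_rng(0)
    targets_log = rng.normal(size=(32, 100, 10))
    preds_log = targets_log + rng.normal(scale=0.05, size=targets_log.shape)

    # Delta-dex (absolute log-space) errors and their summary statistics
    absolute_errors_log = np.abs(preds_log - targets_log)
    rmse_log = np.sqrt(np.mean(absolute_errors_log**2))
    p99_log = np.percentile(absolute_errors_log, 99)

    # Real-space relative errors with a denominator floor, mirroring the diff
    preds, targets = 10.0**preds_log, 10.0**targets_log
    threshold = 1e-10  # hypothetical relative_error_threshold
    relative_errors = np.abs(preds - targets) / np.maximum(np.abs(targets), threshold)
    print(rmse_log, p99_log, np.median(relative_errors))

The denominator floor keeps relative errors finite where targets approach zero, which is why the diff exposes `relative_error_threshold` as a config knob.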
@@ -266,21 +305,161 @@

     # Store metrics
     accuracy_metrics = {
-        "mean_squared_error": mean_squared_error,
-        "mean_absolute_error": mean_absolute_error,
-        "mean_relative_error": np.mean(relative_errors),
-        "median_relative_error": np.median(relative_errors),
-        "max_relative_error": np.max(relative_errors),
-        "min_relative_error": np.min(relative_errors),
-        "absolute_errors": absolute_errors,
-        "relative_errors": relative_errors,
+        "root_mean_squared_error_log": root_mean_squared_error_log,
+        "median_absolute_error_log": median_absolute_error_log,
+        "mean_absolute_error_log": mean_absolute_error_log,
+        "percentile_absolute_error_log": percentile_absolute_error_log,
+        "root_mean_squared_error_real": root_mean_squared_error_real,
+        "median_absolute_error_real": median_absolute_error_real,
+        "mean_absolute_error_real": mean_absolute_error_real,
+        "percentile_absolute_error_real": percentile_absolute_error_real,
+        "median_relative_error": median_relative_error,
+        "mean_relative_error": mean_relative_error,
+        "percentile_relative_error": percentile_relative_error,
+        "error_percentile": percentile,
         "main_model_training_time": train_time,
         "main_model_epochs": n_epochs,
+        "absolute_errors": absolute_errors,
+        "relative_errors": relative_errors,
     }

     return accuracy_metrics


+def evaluate_iterative_predictions(
+    model,
+    surr_name: str,
+    timesteps: np.ndarray,
+    val_loader: DataLoader,
+    conf: dict,
+    labels: list | None = None,
+) -> dict[str, Any]:
+    """
+    Evaluate the iterative predictions of the surrogate model.
+
+    Returns the same set of error metrics as evaluate_accuracy, but over the
+    full trajectory built by re-feeding the last prediction as the next initial state.
+    """
+    # load trained model
+    training_id = conf["training_id"]
+    model.load(training_id, surr_name, model_identifier=f"{surr_name.lower()}_main")
+
+    # get full ground truth (targets) and the one-shot preds (plotted for comparison below)
+    full_preds, targets = model.predict(
+        data_loader=val_loader, leave_log=True, leave_norm=True
+    )
+    targets = targets.detach().cpu().numpy()
+    n_samples, n_timesteps, n_quantities = targets.shape
+
+    original_n_timesteps = model.n_timesteps
+
+    # how many timesteps per chunk
+    iter_interval = 10  # conf["iterative"]["interval"]
+    # batch size same as in run_benchmark
+    surr_idx = conf["surrogates"].index(surr_name)
+    if isinstance(conf["batch_size"], list):
+        batch_size = conf["batch_size"][surr_idx]
+    else:
+        batch_size = conf["batch_size"]
+
+    # container for the piecewise predictions
+    iterative_preds = np.zeros_like(targets)
+
+    # number of chunks
+    n_chunks = (n_timesteps + iter_interval - 1) // iter_interval
+
+    # create timesteps array for the iterative predictions
+    timesteps_full = np.linspace(0, 1, original_n_timesteps)
+
+    for i in range(n_chunks):
+        start = i * iter_interval
+        end = min(start + iter_interval, n_timesteps)
+        model.n_timesteps = (
+            end - start + 1
+        )  # set the number of timesteps for this chunk
+
+        # choose initial state
+        if i == 0:
+            init_state = targets[:, 0, :]
+        else:
+            init_state = iterative_preds[:, start - 1, :]
+
+        # build dummy dataset: only first slice matters for prepare_data
+        ds = np.zeros((n_samples, model.n_timesteps, n_quantities))
+        ds[:, 0, :] = init_state
+
+        # only need the "train" loader for prediction
+        dt = timesteps_full[: model.n_timesteps]
+        train_loader, _, _ = model.prepare_data(
+            dataset_train=ds,
+            dataset_test=None,
+            dataset_val=None,
+            timesteps=dt,
+            batch_size=batch_size,
+            shuffle=False,
+            dataset_train_params=None,
+            dataset_test_params=None,
+            dataset_val_params=None,
+            dummy_timesteps=False,
+        )
+
+        # predict this chunk and insert into the global array
+        preds_chunk, _ = model.predict(
+            data_loader=train_loader, leave_log=True, leave_norm=True
+        )
+        iterative_preds[:, start:end, :] = (
+            preds_chunk[:, 1 : model.n_timesteps, :].detach().cpu().numpy()
+        )
+
+    iterative_preds_log = model.denormalize(iterative_preds, leave_log=True)
+    targets_log = model.denormalize(targets, leave_log=True)
+    iterative_preds = model.denormalize(iterative_preds)
+    full_preds = model.denormalize(full_preds.detach().cpu().numpy())
+    targets = model.denormalize(targets)
+
+    # compute error metrics
+    errors = iterative_preds - targets
+    abs_errors = np.abs(errors)
+    mse = float(np.mean(errors**2))
+    mae = float(np.mean(abs_errors))
+
+    # compute log-space errors
+    abs_errors_log = np.abs(iterative_preds_log - targets_log)
+    rmse_log = float(np.sqrt(np.mean(abs_errors_log**2)))
+    mae_log = float(np.mean(abs_errors_log))
+    percentile = conf.get("error_percentile", 99)
+    percentile_abs_error_log = float(np.percentile(abs_errors_log, percentile))
+
+    errors = np.mean(np.abs(iterative_preds - targets), axis=(1, 2))  # per-sample mean error
+    example_idx = int(np.argsort(np.abs(errors - np.median(errors)))[0])  # sample closest to median
+
+    # Restore original number of timesteps
+    model.n_timesteps = original_n_timesteps
+
+    plot_example_iterative_predictions(
+        surr_name,
+        conf,
+        iterative_preds,
+        full_preds,
+        targets,
+        timesteps,
+        iter_interval=iter_interval,
+        example_idx=example_idx,
+        labels=labels,
+        save=True,
+        show_title=TITLE,
+    )
+
+    return {
+        "root_mean_squared_error_log": rmse_log,
+        "mean_absolute_error_log": mae_log,
+        "percentile_absolute_error_log": percentile_abs_error_log,
+        "mean_squared_error": mse,
+        "mean_absolute_error": mae,
+        "absolute_errors": abs_errors,
+    }
+
+
 def evaluate_dynamic_accuracy(
     model,
     surr_name: str,
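Stripped of the model plumbing (prepare_data, denormalization, plotting), the rollout scheme in evaluate_iterative_predictions reduces to a chunked re-feeding loop. A minimal sketch under that reading — `rollout_chunk` is a hypothetical stand-in for the `prepare_data` + `predict` round trip, which returns the initial state as its first timestep:

    import numpy as np

    def iterative_rollout(rollout_chunk, targets, iter_interval=10):
        # Re-feed the last prediction of each chunk as the next chunk's initial state.
        # rollout_chunk(init_state, n_steps) -> (n_samples, n_steps, n_quantities)
        # is a hypothetical stand-in for model.prepare_data + model.predict.
        n_samples, n_timesteps, n_quantities = targets.shape
        preds = np.zeros_like(targets)
        n_chunks = (n_timesteps + iter_interval - 1) // iter_interval
        for i in range(n_chunks):
            start = i * iter_interval
            end = min(start + iter_interval, n_timesteps)
            # chunk 0 starts from the ground-truth initial state;
            # later chunks start from the model's own last prediction
            init = targets[:, 0, :] if i == 0 else preds[:, start - 1, :]
            chunk = rollout_chunk(init, end - start + 1)  # first entry repeats init
            preds[:, start:end, :] = chunk[:, 1:, :]      # drop the repeated init state
        return preds

With a perfect `rollout_chunk` this reproduces the target trajectory; with a trained surrogate, errors compound across chunks, which is exactly what the returned iterative metrics quantify.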
