Skip to content

Commit 18388f9

Browse files
Merge pull request #53 from robin-janssen/comparative-plots-MR
Merge changes from the paper-plots branch back into the main branch, ignoring files that are not required
2 parents 6c2835b + f1df561 commit 18388f9

6 files changed

Lines changed: 603 additions & 78 deletions

File tree

codes/benchmark/bench_fcts.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from contextlib import redirect_stdout
23
from typing import Any
34

@@ -378,7 +379,6 @@ def evaluate_iterative_predictions(
378379
# container for the piecewise predictions; seed t=0 with ground truth so errors
379380
# are computed only on actual predictions for t>=1 while keeping shape intact
380381
iterative_preds = np.zeros_like(targets)
381-
iterative_preds[:, 0, :] = targets[:, 0, :]
382382

383383
# number of chunks
384384
n_chunks = (n_timesteps + iter_interval - 1) // iter_interval
@@ -433,14 +433,17 @@ def evaluate_iterative_predictions(
433433
)
434434
# We predict steps 1..(chunk_len-1) relative to the provided init state (index 0).
435435
# Map these to global indices [start+1 .. end] inclusively.
436+
if i == 0:
437+
iterative_preds[:, start : end + 1, :] = preds_chunk[:, : model.n_timesteps, :].detach().cpu().numpy()
436438
iterative_preds[:, start + 1 : end + 1, :] = (
437439
preds_chunk[:, 1 : model.n_timesteps, :].detach().cpu().numpy()
438440
)
439441

440442
iterative_preds_log = model.denormalize(iterative_preds, leave_log=True)
443+
full_preds_log = model.denormalize(full_preds, leave_log=True)
441444
targets_log = model.denormalize(targets, leave_log=True)
442445
iterative_preds = model.denormalize(iterative_preds)
443-
full_preds = model.denormalize(full_preds.detach().cpu().numpy())
446+
full_preds_real = model.denormalize(full_preds.detach().cpu().numpy())
444447
targets = model.denormalize(targets)
445448

446449
# compute error metrics
@@ -466,7 +469,7 @@ def evaluate_iterative_predictions(
466469
surr_name,
467470
conf,
468471
iterative_preds,
469-
full_preds,
472+
full_preds_real,
470473
targets,
471474
timesteps,
472475
iter_interval=iter_interval,
@@ -1583,7 +1586,7 @@ def compare_UQ(all_metrics: dict, config: dict) -> None:
15831586
ensemble_errors,
15841587
ensemble_std,
15851588
config,
1586-
flag_fractions=(0.01, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50),
1589+
flag_fractions=(0, 0.025, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50),
15871590
save=True,
15881591
show_title=True,
15891592
)

codes/benchmark/bench_plots.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,6 +1661,7 @@ def plot_errors_over_time(
16611661
elif mode == "iterative":
16621662
# Single backslash inside raw string to render the LaTeX Delta properly
16631663
plt.ylabel(r"Log-MAE ($\Delta dex$)")
1664+
plt.ylim(bottom=0, top=min(np.max(list(mean_errors.values())) * 1.1, 5))
16641665
fname = "iterative_delta_dex_time.png"
16651666
title = "Comparison of Δdex Errors Over Time for Iterative Predictions"
16661667
# Add subtle dashed vertical lines at every n-th timestep if provided and valid
@@ -2035,8 +2036,6 @@ def inference_time_bar_plot(
20352036
# Calculate the upper y-limit to provide space for text
20362037
max_bar = max(means[i] + stds[i] for i in range(len(means)))
20372038
# min_bar = min(means[i] - stds[i] for i in range(len(means)))
2038-
# Temp!
2039-
# ax.set_ylim(min_bar * 0.3, max_bar * 2) # Set limits with some padding
20402039
ax.set_ylim(0, max_bar * 1.2) # Set limits with some padding
20412040

20422041
# Add inference time as text to the bars using the format_time function
@@ -2053,8 +2052,6 @@ def inference_time_bar_plot(
20532052

20542053
ax.set_xlabel("Surrogate Model")
20552054
ax.set_ylabel("Mean Inference Time per Run")
2056-
# Temp!
2057-
# ax.set_yscale("log")
20582055
if show_title:
20592056
ax.set_title("Surrogate Mean Inference Time Comparison")
20602057

@@ -2507,11 +2504,16 @@ def plot_catastrophic_detection_curves(
25072504

25082505
xs, ys = [], []
25092506
for f in flag_fractions:
2510-
unc_thr = np.percentile(u, 100.0 * (1.0 - float(f)))
2511-
flagged = u >= unc_thr
2512-
recall = (flagged & is_cat).sum() / n_cat if n_cat > 0 else 0.0
2513-
xs.append(100.0 * flagged.mean())
2514-
ys.append(100.0 * recall)
2507+
if f <= 0.0:
2508+
xs.append(0.0)
2509+
ys.append(0.0)
2510+
recall = 0.0
2511+
else:
2512+
unc_thr = np.percentile(u, 100.0 * (1.0 - float(f)))
2513+
flagged = u >= unc_thr
2514+
recall = (flagged & is_cat).sum() / n_cat if n_cat > 0 else 0.0
2515+
xs.append(100.0 * flagged.mean())
2516+
ys.append(100.0 * recall)
25152517

25162518
ax.plot(
25172519
xs,
@@ -2988,8 +2990,6 @@ def rel_errors_and_uq(
29882990

29892991
ax1.set_xlabel("Time")
29902992
ax1.set_xlim(timesteps[0], timesteps[-1])
2991-
# Temp!
2992-
# ax1.set_ylim(3e-4, 1)
29932993
ax1.set_ylabel("Relative Error")
29942994
ax1.set_yscale("log")
29952995
ax1.set_title("Comparison of Relative Errors Over Time")
@@ -3020,8 +3020,6 @@ def rel_errors_and_uq(
30203020

30213021
ax2.set_xlabel("Time")
30223022
ax2.set_xlim(timesteps[0], timesteps[-1])
3023-
# Temp!
3024-
# ax2.set_ylim(0, 0.04)
30253023
ax2.set_ylabel("Uncertainty/Absolute Error")
30263024
if show_title:
30273025
ax2.set_title("Comparison of Predictive Uncertainty Over Time")

codes/surrogates/AbstractSurrogate/abstract_surrogate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,10 @@ def denormalize(
449449
Returns:
450450
Tensor | np.ndarray: The denormalized data.
451451
"""
452+
data_type = None
452453
if self.normalisation is not None:
453454
if not leave_norm:
455+
data_type = data.dtype
454456
if self.normalisation["mode"] == "disabled":
455457
...
456458
elif self.normalisation["mode"] == "minmax":
@@ -475,6 +477,13 @@ def denormalize(
475477
if self.normalisation["log10_transform"] and not leave_log:
476478
data = 10**data
477479

480+
# Conserve dtype
481+
if data_type is not None:
482+
if isinstance(data, Tensor):
483+
return data.to(dtype=data_type)
484+
if isinstance(data, np.ndarray):
485+
return data.astype(data_type)
486+
478487
return data
479488

480489
def denormalize_old(self, data: Tensor) -> Tensor:

0 commit comments

Comments (0)