Skip to content

Commit 5e3b735

Browse files
Merge pull request #42 from robin-janssen/add-cloud-dataset
Add cloud dataset
2 parents 33e7a21 + f999399 commit 5e3b735

35 files changed

Lines changed: 484 additions & 232 deletions

codes/benchmark/bench_fcts.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from tabulate import tabulate
88
from torch.utils.data import DataLoader
99

10-
from codes.utils import check_and_load_data
10+
from codes.utils import batch_factor_to_float, check_and_load_data
1111

1212
from .bench_plots import inference_time_bar_plot # int_ext_sparse,
1313
from .bench_plots import ( # plot_generalization_errors,; rel_errors_and_uq,
@@ -98,7 +98,13 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
9898
n_quantities = train_data.shape[2]
9999
n_test_samples = n_timesteps * val_data.shape[0]
100100
n_params = train_params.shape[1] if train_params is not None else 0
101-
model = surrogate_class(device, n_quantities, n_timesteps, n_params, model_config)
101+
model = surrogate_class(
102+
device=device,
103+
n_quantities=n_quantities,
104+
n_timesteps=n_timesteps,
105+
n_parameters=n_params,
106+
config=model_config,
107+
)
102108

103109
# Placeholder for metrics
104110
metrics = {}
@@ -231,7 +237,10 @@ def evaluate_accuracy(
231237
# Calculate relative errors
232238
absolute_errors = np.abs(preds - targets)
233239
mean_absolute_error = np.mean(absolute_errors)
234-
relative_errors = np.abs(absolute_errors / targets)
240+
relative_error_threshold = float(conf.get("relative_error_threshold", 0.0))
241+
relative_errors = np.abs(
242+
absolute_errors / np.maximum(np.abs(targets), relative_error_threshold)
243+
)
235244

236245
# Plot relative errors over time
237246
plot_relative_errors_over_time(
@@ -729,13 +738,17 @@ def evaluate_batchsize(
729738
dict: A dictionary containing batch size training metrics.
730739
"""
731740
training_id = conf["training_id"]
732-
batch_sizes = conf["batch_scaling"]["sizes"].copy()
741+
batch_factors = conf["batch_scaling"]["sizes"].copy()
733742
batch_metrics = {}
734743

735744
# Identify the batch size of the main model
736745
model_idx = conf["surrogates"].index(surr_name)
737746
main_batch_size = conf["batch_size"][model_idx]
738747

748+
batch_sizes = [
749+
int(main_batch_size * batch_factor_to_float(bf)) for bf in batch_factors
750+
]
751+
739752
# Add main batch size to the list of batch sizes
740753
if main_batch_size not in batch_sizes:
741754
batch_sizes.append(main_batch_size)
@@ -840,12 +853,15 @@ def evaluate_UQ(
840853
errors_time = np.mean(errors, axis=(0, 2))
841854
avg_correlation, _ = pearsonr(errors.flatten(), preds_std.flatten())
842855
preds_std_time = np.mean(preds_std, axis=(0, 2))
843-
rel_errors = np.abs(errors / targets)
856+
rel_error_threshold = float(conf.get("relative_error_threshold", 0.0))
857+
rel_errors = np.abs(errors / np.maximum(np.abs(targets), rel_error_threshold))
844858

845859
# Compute a target-weighted, signed difference between predicted uncertainty and error.
846860
# Negative values indicate overconfidence (PU is too low compared to error),
847861
# positive values indicate underconfidence.
848-
weighted_diff = (preds_std - errors) / targets
862+
weighted_diff = (preds_std - errors) / np.maximum(
863+
np.abs(targets), rel_error_threshold
864+
)
849865

850866
# Plots (existing UQ plots)
851867
plot_example_predictions_with_uncertainty(
@@ -971,7 +987,11 @@ def compare_main_losses(metrics: dict, config: dict) -> None:
971987
n_params = metrics[surr_name]["n_params"]
972988
model_config = get_model_config(surr_name, config)
973989
model = surrogate_class(
974-
device, n_quantities, n_timesteps, n_params, model_config
990+
device=device,
991+
n_quantities=n_quantities,
992+
n_timesteps=n_timesteps,
993+
n_parameters=n_params,
994+
config=model_config,
975995
)
976996

977997
def load_losses(model_identifier: str):

codes/benchmark/bench_plots.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from matplotlib.gridspec import GridSpec
88
from scipy.ndimage import gaussian_filter1d
99

10+
from codes.utils import batch_factor_to_float
11+
1012
from .bench_utils import format_time
1113

1214
# Utility functions for plotting
@@ -157,6 +159,9 @@ def plot_relative_errors_over_time(
157159
plt.xlabel("Time")
158160
plt.ylabel("Relative Error")
159161
plt.xlim(timesteps[0], timesteps[-1])
162+
plt.ylim(bottom=1e-8)
163+
if conf["dataset"]["log_timesteps"]:
164+
plt.xscale("log")
160165
if show_title:
161166
plt.title(title)
162167
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
@@ -332,6 +337,8 @@ def plot_average_errors_over_time(
332337
plt.xlim(timesteps[0], timesteps[-1])
333338
plt.ylabel("Mean Absolute Error")
334339
plt.yscale("log")
340+
if conf["dataset"]["log_timesteps"]:
341+
plt.xscale("log")
335342
title = f"Mean Absolute Errors over Time ({mode.capitalize()}, {surr_name})"
336343
filename = f"{mode}_errors_over_time.png"
337344

@@ -450,6 +457,8 @@ def plot_example_mode_predictions(
450457

451458
# Set the x-axis limits based on the timesteps array
452459
ax.set_xlim(timesteps.min(), timesteps.max())
460+
if conf["dataset"]["log_timesteps"]:
461+
ax.set_xscale("log")
453462

454463
# Add a single x-axis label to the bottom of the figure
455464
fig.text(0.5, 0.04, "Time", ha="center", va="center", fontsize=12)
@@ -471,10 +480,10 @@ def plot_example_mode_predictions(
471480

472481
# Set the overall title with details depending on the mode
473482
if mode == "interpolation":
474-
title = f"DeepEnsemble: Example Predictions (Interpolation, {surr_name})\n"
483+
title = f"Interpolation: Example Predictions (Interpolation, {surr_name})\n"
475484
extra_info = f"Sample Index: {example_idx}, Training Interval: {metric}"
476485
elif mode == "extrapolation":
477-
title = f"DeepEnsemble: Example Predictions (Extrapolation, {surr_name})\n"
486+
title = f"Extrapolation: Example Predictions (Extrapolation, {surr_name})\n"
478487
extra_info = f"Sample Index: {example_idx}, Cutoff Timestep: {metric}"
479488
else:
480489
raise ValueError(
@@ -589,6 +598,8 @@ def plot_example_predictions_with_uncertainty(
589598

590599
# Set the x limit exactly from the lowest to the highest timestep
591600
ax.set_xlim(timesteps.min(), timesteps.max())
601+
if conf["dataset"]["log_timesteps"]:
602+
ax.set_xscale("log")
592603

593604
# Add a single x-axis label to the bottom plot
594605
fig.text(0.5, 0.04, "Time", ha="center", va="center", fontsize=12)
@@ -656,6 +667,8 @@ def plot_average_uncertainty_over_time(
656667
plt.xlabel("Time")
657668
plt.ylabel("Average Uncertainty / Mean Absolute Error")
658669
plt.xlim(timesteps[0], timesteps[-1])
670+
if conf["dataset"]["log_timesteps"]:
671+
plt.xscale("log")
659672
if show_title:
660673
plt.title("Average Uncertainty and Mean Absolute Error Over Time")
661674
plt.legend()
@@ -835,10 +848,16 @@ def load_losses(model_identifier: str):
835848

836849
# Batchsize losses
837850
if conf["batch_scaling"]["enabled"]:
838-
batch_sizes = conf["batch_scaling"]["sizes"]
851+
batch_factors = conf["batch_scaling"]["sizes"]
839852
batch_train_losses = []
840853
batch_test_losses = []
841-
for batch_size in batch_sizes:
854+
batch_sizes = []
855+
surr_index = conf["surrogates"].index(surr_name)
856+
main_model_bs = conf["batch_size"][surr_index]
857+
for batch_factor in batch_factors:
858+
batch_factor = batch_factor_to_float(batch_factor)
859+
batch_size = int(main_model_bs * batch_factor)
860+
batch_sizes.append(batch_size)
842861
train_loss, test_loss, epochs = load_losses(
843862
f"{surr_name.lower()}_batchsize_{batch_size}"
844863
)
@@ -966,7 +985,9 @@ def plot_error_distribution_per_quantity(
966985
fig.align_ylabels()
967986

968987
plt.xscale("log") # Log scale for error magnitudes
969-
plt.xlim(10**x_min, 10**x_max) # Set x-axis range based on log-space calculations
988+
plt.xlim(
989+
np.maximum(10**x_min, 1e-8), 10**x_max
990+
) # Set x-axis range based on log-space calculations
970991
plt.xlabel("Relative Error")
971992
if show_title:
972993
if num_plots > 1:
@@ -1423,6 +1444,8 @@ def plot_relative_errors(
14231444
plt.yscale("log")
14241445
if show_title:
14251446
plt.title("Comparison of Relative Errors Over Time")
1447+
if config["dataset"]["log_timesteps"]:
1448+
plt.xscale("log")
14261449
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
14271450

14281451
if save and config:
@@ -1483,6 +1506,8 @@ def plot_uncertainty_over_time_comparison(
14831506
plt.xlim(timesteps[0], timesteps[-1])
14841507
plt.ylabel("Uncertainty / MAE")
14851508
plt.yscale("log")
1509+
if config["dataset"]["log_timesteps"]:
1510+
plt.xscale("log")
14861511
if show_title:
14871512
plt.title("Comparison of Predictive Uncertainty and True MAE over Time")
14881513
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
@@ -2266,7 +2291,7 @@ def plot_error_distribution_comparative(
22662291
)
22672292

22682293
plt.xscale("log") # Log scale for error magnitudes
2269-
plt.xlim(10**x_min, 10**x_max) # Set x-axis range based on log-space calculations
2294+
plt.xlim(np.maximum(10**x_min, 1e-8), 10**x_max) # Set x-axis range
22702295

22712296
if mode == "main":
22722297
title = "Distribution of Surrogate Relative Errors"

codes/benchmark/bench_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -452,10 +452,7 @@ def write_metrics_to_yaml(surr_name: str, conf: dict, metrics: dict) -> None:
452452
write_metrics = convert_to_standard_types(write_metrics)
453453

454454
# Make results directory
455-
try:
456-
os.makedirs(f"results/{conf['training_id']}")
457-
except FileExistsError:
458-
pass
455+
os.makedirs(f"results/{conf['training_id']}", exist_ok=True)
459456

460457
with open(
461458
f"results/{conf['training_id']}/{surr_name.lower()}_metrics.yaml",

0 commit comments

Comments (0)