Skip to content

Commit 33e7a21

Browse files
Merge pull request #41 from robin-janssen/ode-parameters
Implement architecture variants for ODE parameters
2 parents 7304067 + 4416b49 commit 33e7a21

47 files changed

Lines changed: 2086 additions & 992 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

codes/benchmark/bench_fcts.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,23 +76,34 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
7676
else:
7777
batch_size = conf["batch_size"]
7878

79-
train_data, test_data, val_data, timesteps, n_train_samples, _, labels = (
80-
check_and_load_data(
81-
conf["dataset"]["name"],
82-
verbose=False,
83-
log=conf["dataset"]["log10_transform"],
84-
normalisation_mode=conf["dataset"]["normalise"],
85-
)
79+
# Load full data and parameters
80+
(
81+
(train_data, test_data, val_data),
82+
(train_params, test_params, val_params),
83+
timesteps,
84+
n_train_samples,
85+
_,
86+
labels,
87+
) = check_and_load_data(
88+
conf["dataset"]["name"],
89+
verbose=False,
90+
log=conf["dataset"]["log10_transform"],
91+
log_params=conf.get("log10_transform_params", False),
92+
normalisation_mode=conf["dataset"]["normalise"],
93+
tolerance=conf["dataset"]["tolerance"],
8694
)
95+
8796
model_config = get_model_config(surr_name, conf)
8897
n_timesteps = train_data.shape[1]
8998
n_quantities = train_data.shape[2]
9099
n_test_samples = n_timesteps * val_data.shape[0]
91-
model = surrogate_class(device, n_quantities, n_timesteps, model_config)
100+
n_params = train_params.shape[1] if train_params is not None else 0
101+
model = surrogate_class(device, n_quantities, n_timesteps, n_params, model_config)
92102

93103
# Placeholder for metrics
94104
metrics = {}
95105
metrics["timesteps"] = timesteps
106+
metrics["n_params"] = n_params
96107

97108
# Create dataloader for the validation data
98109
_, _, val_loader = model.prepare_data(
@@ -101,7 +112,11 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
101112
dataset_val=val_data,
102113
timesteps=timesteps,
103114
batch_size=batch_size,
104-
shuffle=False,
115+
shuffle=True,
116+
dataset_train_params=train_params,
117+
dataset_test_params=test_params,
118+
dataset_val_params=val_params,
119+
dummy_timesteps=True,
105120
)
106121

107122
# Plot training losses
@@ -953,8 +968,11 @@ def compare_main_losses(metrics: dict, config: dict) -> None:
953968
surrogate_class = get_surrogate(surr_name)
954969
n_timesteps = metrics[surr_name]["timesteps"].shape[0]
955970
n_quantities = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
971+
n_params = metrics[surr_name]["n_params"]
956972
model_config = get_model_config(surr_name, config)
957-
model = surrogate_class(device, n_quantities, n_timesteps, model_config)
973+
model = surrogate_class(
974+
device, n_quantities, n_timesteps, n_params, model_config
975+
)
958976

959977
def load_losses(model_identifier: str):
960978
model.load(training_id, surr_name, model_identifier=model_identifier)

codes/benchmark/bench_plots.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def save_plot(
2020
dpi: int = 300,
2121
base_dir: str = "plots", # Base directory for saving plots
2222
increase_count: bool = False, # Whether to increase the count for existing filenames
23-
format: str = "pdf", # Format for saving the plot
23+
format: str = "jpg", # Format for saving the plot
2424
) -> None:
2525
"""
2626
Save the plot to a file, creating necessary directories if they don't exist.
@@ -122,9 +122,9 @@ def plot_relative_errors_over_time(
122122
p99_lower = np.percentile(relative_errors, 0.5, axis=(0, 2))
123123

124124
plt.figure(figsize=(6, 4))
125-
mean_label = f"Mean Error\nMean={mean*100:.2f}%"
125+
mean_label = f"Mean Error\nMean={mean * 100:.2f}%"
126126
plt.plot(timesteps, mean_errors, label=mean_label, color="blue")
127-
median_label = f"Median Error\nMedian={median*100:.2f}%"
127+
median_label = f"Median Error\nMedian={median * 100:.2f}%"
128128
plt.plot(timesteps, median_errors, label=median_label, color="red")
129129

130130
# Shading areas
@@ -816,7 +816,9 @@ def load_losses(model_identifier: str):
816816
uq_train_losses = [main_train_loss]
817817
uq_test_losses = [main_test_loss]
818818
for i in range(n_models - 1):
819-
train_loss, test_loss, epochs = load_losses(f"{surr_name.lower()}_UQ_{i+1}")
819+
train_loss, test_loss, epochs = load_losses(
820+
f"{surr_name.lower()}_UQ_{i + 1}"
821+
)
820822
uq_train_losses.append(train_loss)
821823
uq_test_losses.append(test_loss)
822824
plot_losses(
@@ -1397,7 +1399,7 @@ def plot_relative_errors(
13971399

13981400
for i, surrogate in enumerate(mean_errors.keys()):
13991401
mean = np.mean(mean_errors[surrogate])
1400-
mean_label = f"{surrogate}\nMean = {mean*100:.2f}%"
1402+
mean_label = f"{surrogate}\nMean = {mean * 100:.2f}%"
14011403
plt.plot(
14021404
timesteps,
14031405
mean_errors[surrogate],
@@ -1406,7 +1408,7 @@ def plot_relative_errors(
14061408
linestyle=linestyles[0],
14071409
)
14081410
median = np.mean(median_errors[surrogate])
1409-
median_label = f"{surrogate}\nMedian = {median*100:.2f}%"
1411+
median_label = f"{surrogate}\nMedian = {median * 100:.2f}%"
14101412
plt.plot(
14111413
timesteps,
14121414
median_errors[surrogate],
@@ -2895,7 +2897,7 @@ def rel_errors_and_uq(
28952897

28962898
for i, surrogate in enumerate(mean_errors.keys()):
28972899
mean = np.mean(mean_errors[surrogate])
2898-
mean_label = f"{surrogate} Mean={mean*100:.2f} %"
2900+
mean_label = f"{surrogate} Mean={mean * 100:.2f} %"
28992901
ax1.plot(
29002902
timesteps,
29012903
mean_errors[surrogate],
@@ -2904,7 +2906,7 @@ def rel_errors_and_uq(
29042906
linestyle=linestyles[0],
29052907
)
29062908
median = np.mean(median_errors[surrogate])
2907-
median_label = f"{surrogate} Median={median*100:.2f} %"
2909+
median_label = f"{surrogate} Median={median * 100:.2f} %"
29082910
ax1.plot(
29092911
timesteps,
29102912
median_errors[surrogate],

codes/benchmark/bench_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import importlib.util
33
import inspect
44
import os
5+
import time
56
from copy import deepcopy
67
from dataclasses import asdict
78

@@ -13,8 +14,6 @@
1314
from codes.surrogates import SurrogateModel, surrogate_classes
1415
from codes.utils import read_yaml_config
1516

16-
import time
17-
1817

1918
def check_surrogate(surrogate: str, conf: dict) -> None:
2019
"""
@@ -219,7 +218,7 @@ def get_required_models_list(surrogate: str, conf: dict) -> list:
219218
if conf["uncertainty"]["enabled"]:
220219
n_models = conf["uncertainty"]["ensemble_size"]
221220
required_models.extend(
222-
[f"{surrogate.lower()}_UQ_{i+1}.pth" for i in range(n_models - 1)]
221+
[f"{surrogate.lower()}_UQ_{i + 1}.pth" for i in range(n_models - 1)]
223222
)
224223

225224
return required_models
@@ -296,7 +295,7 @@ def measure_memory_footprint(model: torch.nn.Module, inputs: tuple) -> dict:
296295

297296
# Prepare inputs: move them to the target device
298297
if isinstance(inputs, (list, tuple)):
299-
inputs = tuple(i.to(device) for i in inputs)
298+
inputs = tuple((i.to(device) if i is not None else i) for i in inputs)
300299
else:
301300
inputs = inputs.to(device)
302301

@@ -640,10 +639,9 @@ def save_table_csv(headers: list, rows: list, config: dict) -> None:
640639
"""
641640
# Convert each cell to a string and remove asterisks
642641
cleaned_rows = [
643-
[str(cell).replace("*", "").strip() for cell in row]
644-
for row in rows
642+
[str(cell).replace("*", "").strip() for cell in row] for row in rows
645643
]
646-
644+
647645
csv_path = f"results/{config['training_id']}/metrics_table.csv"
648646
with open(csv_path, "w", newline="") as f:
649647
writer = csv.writer(f)

codes/surrogates/AbstractSurrogate/surrogates.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import dataclasses
22
import os
3+
import time
34
from abc import ABC, abstractmethod
45
from datetime import datetime
56
from typing import Any, TypeVar
67

78
import numpy as np
9+
import optuna
810
import torch
911
import yaml
1012
from torch import Tensor, nn
@@ -71,7 +73,7 @@ class AbstractSurrogateModel(ABC, nn.Module):
7173
model_name: str,
7274
subfolder: str,
7375
training_id: str,
74-
data_params: dict,
76+
data_info: dict,
7577
) -> None:
7678
Saves the model to disk.
7779
@@ -99,6 +101,7 @@ def __init__(
99101
device: str | None = None,
100102
n_quantities: int = 29,
101103
n_timesteps: int = 100,
104+
n_parameters: int = 0,
102105
config: dict | None = None,
103106
):
104107
super().__init__()
@@ -109,6 +112,7 @@ def __init__(
109112
self.device = device
110113
self.n_quantities = n_quantities
111114
self.n_timesteps = n_timesteps
115+
self.n_parameters = n_parameters
112116
self.L1 = nn.L1Loss()
113117
self.config = config if config is not None else {}
114118
self.train_duration = None
@@ -265,7 +269,7 @@ def save(
265269
model_name (str): The name of the model.
266270
subfolder (str): The subfolder to save the model in.
267271
training_id (str): The training identifier.
268-
data_params (dict): The data parameters.
272+
data_info (dict): The data parameters.
269273
"""
270274

271275
# Make the model directory
@@ -329,7 +333,7 @@ def save(
329333

330334
save_attributes = {
331335
k: v
332-
for k, v in self.__dict__.items()
336+
for k, v in self.__dict__.copy().items()
333337
if k != "state_dict" and not k.startswith("_")
334338
}
335339
model_dict = {"state_dict": self.state_dict(), "attributes": save_attributes}
@@ -392,6 +396,7 @@ def setup_progress_bar(self, epochs: int, position: int, description: str):
392396
Returns:
393397
tqdm: The progress bar.
394398
"""
399+
395400
bar_format = "{l_bar}{bar}| {n_fmt:>5}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt} {postfix}]"
396401
progress_bar = tqdm(
397402
range(epochs),
@@ -401,6 +406,9 @@ def setup_progress_bar(self, epochs: int, position: int, description: str):
401406
bar_format=bar_format,
402407
)
403408

409+
# Only used for time_pruning in multi objective optimisation
410+
self._trial_start_time = time.time()
411+
404412
return progress_bar
405413

406414
def denormalize(self, data: Tensor) -> Tensor:
@@ -430,5 +438,64 @@ def denormalize(self, data: Tensor) -> Tensor:
430438

431439
return data
432440

441+
def time_pruning(self, current_epoch: int, total_epochs: int) -> None:
    """
    Prune an Optuna trial whose projected total runtime exceeds the study's
    runtime threshold.

    A warmup period of ``max(50, 2% of total_epochs)`` epochs is honoured
    first: no pruning decision is made before it completes, so the per-epoch
    time estimate has had a chance to settle. After warmup, the average epoch
    duration (measured from ``self._trial_start_time``, which is set in
    ``setup_progress_bar``) is extrapolated to ``total_epochs`` and compared
    against the ``runtime_threshold`` user attribute of the trial's study.
    If no trial/study/threshold is available, the method never prunes.

    Args:
        current_epoch (int): The current epoch count.
        total_epochs (int): The planned total number of epochs.

    Raises:
        optuna.TrialPruned: If the projected total runtime exceeds the
            study's ``runtime_threshold``.
    """
    # Warmup: at least 50 epochs, or 2% of the planned epochs if that is
    # larger. Pruning before this point would extrapolate from too few
    # epoch timings to be reliable.
    warmup_epochs = max(50, int(total_epochs * 0.02))
    if current_epoch < warmup_epochs:
        return

    elapsed = time.time() - self._trial_start_time
    completed_epochs = max(current_epoch, 1)  # guard against division by zero
    average_epoch_time = elapsed / completed_epochs
    projected_total_time = average_epoch_time * total_epochs

    # The threshold lives on the study as a user attribute; without an
    # attached trial (e.g. a plain training run) there is nothing to prune.
    if self.optuna_trial is not None and hasattr(self.optuna_trial, "study"):
        threshold = self.optuna_trial.study.user_attrs.get(
            "runtime_threshold", None
        )
    else:
        threshold = None

    if threshold is not None and projected_total_time > threshold:
        if self.optuna_trial is not None:
            tqdm.write(
                f"[time_pruning] Projected total time {projected_total_time:.1f}s exceeds threshold {threshold:.1f}s. Pruning trial."
            )
            # Record the reason on the trial so it is visible in the study
            # dashboard / dataframe after pruning. Only touch the trial
            # inside the None-guard to avoid an AttributeError.
            self.optuna_trial.set_user_attr(
                "prune_reason",
                f"Projected runtime {projected_total_time:.1f}s exceeds threshold {threshold:.1f}s",
            )
        raise optuna.TrialPruned(
            f"Projected total time {projected_total_time:.1f}s exceeds threshold {threshold:.1f}s"
        )
433500

434501
SurrogateModel = TypeVar("SurrogateModel", bound=AbstractSurrogateModel)

0 commit comments

Comments
 (0)