Skip to content

Commit b5d8ab9

Browse files
committed
Merge branch 'main' into architecture_refactoring
2 parents 7208b96 + c3f7228 commit b5d8ab9

14 files changed

Lines changed: 133 additions & 52 deletions

File tree

codes/benchmark/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
read_yaml_config,
7474
save_table_csv,
7575
write_metrics_to_yaml,
76+
measure_inference_time,
7677
)
7778

7879
__all__ = [
@@ -146,4 +147,5 @@
146147
"make_comparison_csv",
147148
"save_table_csv",
148149
"get_model_config",
150+
"measure_inference_time",
149151
]

codes/benchmark/bench_fcts.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import time
21
from contextlib import redirect_stdout
32
from typing import Any
43

@@ -40,6 +39,7 @@
4039
get_model_config,
4140
get_surrogate,
4241
make_comparison_csv,
42+
measure_inference_time,
4343
measure_memory_footprint,
4444
save_table_csv,
4545
write_metrics_to_yaml,
@@ -349,13 +349,12 @@ def time_inference(
349349
n_runs: int = 5,
350350
) -> dict[str, Any]:
351351
"""
352-
Time the inference of the surrogate model.
352+
Time the inference of the surrogate model (full version with metrics).
353353
354354
Args:
355355
model: Instance of the surrogate model class.
356356
surr_name (str): The name of the surrogate model.
357357
test_loader (DataLoader): The DataLoader object containing the test data.
358-
timesteps (np.ndarray): The timesteps array.
359358
conf (dict): The configuration dictionary.
360359
n_test_samples (int): The number of test samples.
361360
n_runs (int, optional): Number of times to run the inference for timing.
@@ -366,35 +365,19 @@ def time_inference(
366365
training_id = conf["training_id"]
367366
model.load(training_id, surr_name, model_identifier=f"{surr_name.lower()}_main")
368367

369-
# Run inference multiple times and record the durations
370-
inference_times = []
371-
for _ in range(n_runs):
372-
# _, _ = model.predict(data_loader=test_loader)
373-
total_time = 0
374-
with torch.inference_mode():
375-
for inputs in test_loader:
376-
start_time = time.perf_counter()
377-
_, _ = model.forward(inputs)
378-
end_time = time.perf_counter()
379-
total_time += end_time - start_time
380-
# total_time /= n_test_samples
381-
inference_times.append(total_time)
382-
383-
# Calculate metrics
368+
inference_times = measure_inference_time(model, test_loader, n_runs=n_runs)
369+
384370
mean_inference_time = np.mean(inference_times)
385371
std_inference_time = np.std(inference_times)
386372

387-
# Store metrics
388-
timing_metrics = {
373+
return {
389374
"mean_inference_time_per_run": mean_inference_time,
390375
"std_inference_time_per_run": std_inference_time,
391376
"num_predictions": n_test_samples,
392377
"mean_inference_time_per_prediction": mean_inference_time / n_test_samples,
393378
"std_inference_time_per_prediction": std_inference_time / n_test_samples,
394379
}
395380

396-
return timing_metrics
397-
398381

399382
def evaluate_compute(
400383
model, surr_name: str, test_loader: DataLoader, conf: dict

codes/benchmark/bench_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
import numpy as np
99
import torch
1010
import yaml
11+
from torch.utils.data import DataLoader
1112

1213
from codes.surrogates import SurrogateModel, surrogate_classes
1314
from codes.utils import read_yaml_config
1415

16+
import time
17+
1518

1619
def check_surrogate(surrogate: str, conf: dict) -> None:
1720
"""
@@ -699,3 +702,32 @@ def get_model_config(surr_name: str, config: dict) -> dict:
699702
model_config = {}
700703

701704
return model_config
705+
706+
707+
def measure_inference_time(
    model,
    test_loader: DataLoader,
    n_runs: int = 5,
) -> list[float]:
    """
    Measure total inference time over a DataLoader across multiple runs.

    Only the forward passes are timed; per-batch data-loading overhead
    between measurements is excluded from the totals.

    Args:
        model: Model instance with a `.forward()` method.
        test_loader (DataLoader): Loader with test data.
        n_runs (int): Number of repeated runs for averaging.

    Returns:
        list[float]: List of total inference times per run (in seconds).
    """
    # CUDA kernels launch asynchronously: without an explicit synchronize,
    # perf_counter would only measure the (near-instant) kernel launch,
    # not the actual inference work. Sync before and after each timed
    # forward pass when a CUDA device is present.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else None

    inference_times = []
    for _ in range(n_runs):
        total_time = 0.0  # float accumulator for perf_counter deltas
        with torch.inference_mode():
            for inputs in test_loader:
                if sync is not None:
                    sync()
                start_time = time.perf_counter()
                _, _ = model.forward(inputs)
                if sync is not None:
                    sync()
                end_time = time.perf_counter()
                total_time += end_time - start_time
        inference_times.append(total_time)
    return inference_times

codes/surrogates/DeepONet/deeponet.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def fit(
286286
epochs: int,
287287
position: int = 0,
288288
description: str = "Training DeepONet",
289+
multi_objective: bool = False,
289290
) -> None:
290291
"""
291292
Train the MultiONet model.
@@ -296,6 +297,8 @@ def fit(
296297
epochs (int, optional): The number of epochs to train the model.
297298
position (int): The position of the progress bar.
298299
description (str): The description for the progress bar.
300+
multi_objective (bool): Whether multi-objective optimization is used.
301+
If True, trial.report is not used (not supported by Optuna).
299302
300303
Returns:
301304
None. The training loss, test loss, and MAE are stored in the model.
@@ -335,7 +338,8 @@ def fit(
335338
progress_bar.set_postfix(postfix)
336339

337340
# Report the loss to Optuna and check for pruning
338-
if self.optuna_trial is not None:
341+
if self.optuna_trial is not None and not multi_objective:
342+
339343
self.optuna_trial.report(test_losses[index], epoch)
340344
if self.optuna_trial.should_prune():
341345
raise optuna.TrialPruned()

codes/surrogates/FCNN/fcnn.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,23 @@ def fit(
159159
epochs: int,
160160
position: int = 0,
161161
description: str = "Training FullyConnected",
162+
multi_objective: bool = False,
162163
) -> None:
164+
"""
165+
Train the FullyConnected model.
166+
167+
Args:
168+
train_loader (DataLoader): The DataLoader object containing the training data.
169+
test_loader (DataLoader): The DataLoader object containing the test data.
170+
epochs (int, optional): The number of epochs to train the model.
171+
position (int): The position of the progress bar.
172+
description (str): The description for the progress bar.
173+
multi_objective (bool): Whether multi-objective optimization is used.
174+
If True, trial.report is not used (not supported by Optuna).
175+
176+
Returns:
177+
None. The training loss, test loss, and MAE are stored in the model.
178+
"""
163179
self.n_train_samples = int(len(train_loader.dataset) / self.n_timesteps)
164180
# criterion = nn.MSELoss(reduction="sum")
165181
criterion = nn.MSELoss()
@@ -196,7 +212,7 @@ def fit(
196212
progress_bar.set_postfix(postfix)
197213

198214
# Report the test loss to Optuna
199-
if self.optuna_trial is not None:
215+
if self.optuna_trial is not None and not multi_objective:
200216
self.optuna_trial.report(test_losses[index], step=epoch)
201217
if self.optuna_trial.should_prune():
202218
raise optuna.TrialPruned()

codes/surrogates/LatentNeuralODE/latent_neural_ode.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def fit(
145145
epochs: int,
146146
position: int = 0,
147147
description: str = "Training LatentNeuralODE",
148+
multi_objective: bool = False,
148149
) -> None:
149150
"""
150151
Fits the model to the training data. Sets the train_loss and test_loss attributes.
@@ -156,6 +157,9 @@ def fit(
156157
epochs (int | None): The number of epochs to train the model. If None, uses the value from the config.
157158
position (int): The position of the progress bar.
158159
description (str): The description for the progress bar.
160+
multi_objective (bool): Whether multi-objective optimization is used.
161+
If True, trial.report is not used (not supported by Optuna).
162+
159163
"""
160164
# optimizer = Adam(self.model.parameters(), lr=self.config.learning_rate)
161165
optimizer = AdamWScheduleFree(
@@ -211,7 +215,7 @@ def fit(
211215
progress_bar.set_postfix(postfix)
212216

213217
# Report loss to Optuna and prune if necessary
214-
if self.optuna_trial is not None:
218+
if self.optuna_trial is not None and not multi_objective:
215219
self.optuna_trial.report(test_losses[index], step=epoch)
216220
if self.optuna_trial.should_prune():
217221
raise optuna.TrialPruned()

codes/surrogates/LatentPolynomial/latent_poly.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,19 @@ def fit(
144144
epochs: int,
145145
position: int = 0,
146146
description: str = "Training LatentPoly",
147+
multi_objective: bool = False,
147148
) -> None:
148149
"""
149150
Fit the model to the training data.
150151
151152
Args:
152-
train_loader (DataLoader): Training data loader.
153-
test_loader (DataLoader): Test data loader.
154-
epochs (int): Number of training epochs.
155-
position (int): Progress bar position.
156-
description (str): Description for the progress bar.
153+
train_loader (DataLoader): The data loader for the training data.
154+
test_loader (DataLoader): The data loader for the test data.
155+
epochs (int | None): The number of epochs to train the model. If None, uses the value from the config.
156+
position (int): The position of the progress bar.
157+
description (str): The description for the progress bar.
158+
multi_objective (bool): Whether multi-objective optimization is used.
159+
If True, trial.report is not used (not supported by Optuna).
157160
"""
158161
optimizer = AdamWScheduleFree(
159162
self.model.parameters(), lr=self.config.learning_rate
@@ -199,7 +202,8 @@ def fit(
199202
}
200203
)
201204

202-
if self.optuna_trial is not None:
205+
# Report loss to Optuna and prune if necessary
206+
if self.optuna_trial is not None and not multi_objective:
203207
self.optuna_trial.report(test_losses[index], step=epoch)
204208
if self.optuna_trial.should_prune():
205209
raise optuna.TrialPruned()

codes/tune/optuna_fcts.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22
import queue
33
from distutils.util import strtobool
44

5+
import numpy as np
56
import optuna
67
import torch
78
import torch.nn as nn
89
import yaml
910

10-
from codes.benchmark.bench_utils import get_model_config, get_surrogate
11+
from codes.benchmark.bench_utils import (
12+
get_model_config,
13+
get_surrogate,
14+
measure_inference_time,
15+
)
1116
from codes.utils import check_and_load_data, make_description, set_random_seeds
1217
from codes.utils.data_utils import get_data_subset
1318

@@ -132,9 +137,10 @@ def objective(trial):
132137

133138
def training_run(
134139
trial: optuna.Trial, device: str, config: dict, study_name: str
135-
) -> float:
140+
) -> float | tuple[float, float]:
136141
"""
137142
Run the training for a single Optuna trial and return the loss.
143+
In multi-objective mode, also returns the mean inference time.
138144
139145
Args:
140146
trial (optuna.Trial): Optuna trial object.
@@ -143,9 +149,11 @@ def training_run(
143149
study_name (str): Name of the study.
144150
145151
Returns:
146-
float: Loss value.
152+
float: Loss value in single objective mode.
153+
tuple[float, float]: (loss, mean_inference_time) in multi objective mode.
147154
"""
148155

156+
download_data(config["dataset"]["name"], verbose=False)
149157
train_data, test_data, val_data, timesteps, _, data_params, _ = check_and_load_data(
150158
config["dataset"]["name"],
151159
verbose=False,
@@ -195,6 +203,7 @@ def training_run(
195203
epochs=config["epochs"],
196204
position=pos,
197205
description=description,
206+
multi_objective=config["multi_objective"],
198207
)
199208

200209
criterion = torch.nn.MSELoss()
@@ -210,4 +219,11 @@ def training_run(
210219
base_dir="",
211220
training_id=savepath,
212221
)
213-
return loss
222+
223+
# Check if we're running multi-objective optimisation
224+
if config["multi_objective"]:
225+
# Measure inference time
226+
inference_times = measure_inference_time(model, test_loader)
227+
return loss, np.mean(inference_times)
228+
else:
229+
return loss

codes/utils/data_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,20 +483,22 @@ def update_to(self, b=1, bsize=1, tsize=None):
483483
self.update(b * bsize - self.n)
484484

485485

486-
def download_data(dataset_name: str, path: str | None = None):
486+
def download_data(dataset_name: str, path: str | None = None, verbose: bool = True):
487487
"""
488488
Download the specified dataset if it is not present, with a progress bar.
489489
Args:
490490
dataset_name (str): The name of the dataset.
491491
path (str, optional): The path to save the dataset. If None, the default data directory is used.
492+
verbose (bool): Whether to print information about the download progress.
492493
"""
493494
data_path = (
494495
os.path.abspath(f"datasets/{dataset_name.lower()}/data.hdf5")
495496
if path is None
496497
else os.path.abspath(path)
497498
)
498499
if os.path.isfile(data_path):
499-
print(f"Dataset '{dataset_name}' already downloaded at {data_path}.")
500+
if verbose:
501+
print(f"Dataset '{dataset_name}' already exists at {data_path}.")
500502
return
501503

502504
with open("datasets/data_sources.yaml", "r", encoding="utf-8") as file:
225 KB
Loading

0 commit comments

Comments
 (0)