Skip to content

Commit 5bb4c4a

Browse files
Merge pull request #39 from robin-janssen/architecture_refactoring
Architecture refactoring
2 parents c3f7228 + 893432d commit 5bb4c4a

37 files changed

Lines changed: 1132 additions & 638 deletions

.gitignore

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ profiling*
1616
study
1717
sindy
1818
lorenzo_data.ipynb
19-
optuna_runs/models
20-
optuna_runs/studies
21-
optuna_runs/plots
19+
tuned/models
20+
tuned/studies
21+
tuned/plots
2222
scripts
2323
build/
2424
dist/
@@ -36,4 +36,4 @@ docs/_build/
3636
.doctrees/
3737
docs/docs/
3838
.coverage
39-
optuna_runs/
39+
tuned/

codes/benchmark/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
plot_dynamic_correlation_heatmap,
3535
plot_error_correlation_heatmap,
3636
plot_error_distribution_comparative,
37-
plot_error_distribution_per_chemical,
37+
plot_error_distribution_per_quantity,
3838
plot_example_mode_predictions,
3939
plot_example_predictions_with_uncertainty,
4040
plot_generalization_error_comparison,
@@ -110,7 +110,7 @@
110110
"plot_uncertainty_vs_errors",
111111
"plot_uncertainty_confidence",
112112
"plot_surr_losses",
113-
"plot_error_distribution_per_chemical",
113+
"plot_error_distribution_per_quantity",
114114
"plot_losses",
115115
"plot_loss_comparison",
116116
"plot_loss_comparison_train_duration",

codes/benchmark/bench_fcts.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
plot_dynamic_correlation_heatmap,
2020
plot_error_correlation_heatmap,
2121
plot_error_distribution_comparative,
22-
plot_error_distribution_per_chemical,
22+
plot_error_distribution_per_quantity,
2323
plot_example_mode_predictions,
2424
plot_example_predictions_with_uncertainty,
2525
plot_generalization_error_comparison,
@@ -86,9 +86,9 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
8686
)
8787
model_config = get_model_config(surr_name, conf)
8888
n_timesteps = train_data.shape[1]
89-
n_chemicals = train_data.shape[2]
89+
n_quantities = train_data.shape[2]
9090
n_test_samples = n_timesteps * val_data.shape[0]
91-
model = surrogate_class(device, n_chemicals, n_timesteps, model_config)
91+
model = surrogate_class(device, n_quantities, n_timesteps, model_config)
9292

9393
# Placeholder for metrics
9494
metrics = {}
@@ -185,25 +185,25 @@ def evaluate_accuracy(
185185
labels: list | None = None,
186186
) -> dict[str, Any]:
187187
"""
188-
Evaluate the accuracy of the surrogate model.
189-
190-
Args:
191-
model: Instance of the surrogate model class.
192-
surr_name (str): The name of the surrogate model.
193-
timesteps (np.ndarray): The timesteps array.
194-
test_loader (DataLoader): The DataLoader object containing the test data.
195-
conf (dict): The configuration dictionary.
196-
labels (list, optional): The labels for the chemical species.
197-
198-
Returns:
199-
dict: A dictionary containing accuracy metrics.
188+
Evaluate the accuracy of the surrogate model.
189+
quantitiesquantities
190+
Args:
191+
model: Instance of the surrogate model class.
192+
surr_name (str): The name of the surrogate model.
193+
timesteps (np.ndarray): The timesteps array.
194+
test_loader (DataLoader): The DataLoader object containing the test data.
195+
conf (dict): The configuration dictionary.
196+
labels (list, optional): The labels for the quantities.
197+
198+
Returns:
199+
dict: A dictionary containing accuracy metrics.
200200
"""
201201
training_id = conf["training_id"]
202202

203203
# Load the model
204204
model.load(training_id, surr_name, model_identifier=f"{surr_name.lower()}_main")
205205
train_time = model.train_duration
206-
num_chemicals = model.n_chemicals
206+
num_quantities = model.n_quantities
207207
model_index = conf["surrogates"].index(surr_name)
208208
n_epochs = conf["epochs"][model_index]
209209

@@ -229,12 +229,12 @@ def evaluate_accuracy(
229229
show_title=TITLE,
230230
)
231231

232-
plot_error_distribution_per_chemical(
232+
plot_error_distribution_per_quantity(
233233
surr_name,
234234
conf,
235235
relative_errors,
236-
chemical_names=labels,
237-
num_chemicals=num_chemicals,
236+
quantity_names=labels,
237+
num_quantities=num_quantities,
238238
save=True,
239239
show_title=TITLE,
240240
)
@@ -431,7 +431,7 @@ def evaluate_interpolation(
431431
test_loader (DataLoader): The DataLoader object containing the test data.
432432
timesteps (np.ndarray): The timesteps array.
433433
conf (dict): The configuration dictionary.
434-
labels (list, optional): The labels for the chemical species.
434+
labels (list, optional): The labels for the quantities.
435435
436436
Returns:
437437
dict: A dictionary containing interpolation metrics.
@@ -529,7 +529,7 @@ def evaluate_extrapolation(
529529
test_loader (DataLoader): The DataLoader object containing the test data.
530530
timesteps (np.ndarray): The timesteps array.
531531
conf (dict): The configuration dictionary.
532-
labels (list, optional): The labels for the chemical species.
532+
labels (list, optional): The labels for the quantities.
533533
534534
Returns:
535535
dict: A dictionary containing extrapolation metrics.
@@ -793,7 +793,7 @@ def evaluate_UQ(
793793
test_loader (DataLoader): The DataLoader object containing the test data.
794794
timesteps (np.ndarray): The timesteps array.
795795
conf (dict): The configuration dictionary.
796-
labels (list, optional): The labels for the chemical species.
796+
labels (list, optional): The labels for the quantities.
797797
798798
Returns:
799799
dict: A dictionary containing UQ metrics.
@@ -952,9 +952,9 @@ def compare_main_losses(metrics: dict, config: dict) -> None:
952952
training_id = config["training_id"]
953953
surrogate_class = get_surrogate(surr_name)
954954
n_timesteps = metrics[surr_name]["timesteps"].shape[0]
955-
n_chemicals = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
955+
n_quantities = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
956956
model_config = get_model_config(surr_name, config)
957-
model = surrogate_class(device, n_chemicals, n_timesteps, model_config)
957+
model = surrogate_class(device, n_quantities, n_timesteps, model_config)
958958

959959
def load_losses(model_identifier: str):
960960
model.load(training_id, surr_name, model_identifier=model_identifier)
@@ -1001,9 +1001,9 @@ def load_losses(model_identifier: str):
10011001
# training_id = config["training_id"]
10021002
# surrogate_class = get_surrogate(surr_name)
10031003
# n_timesteps = metrics[surr_name]["timesteps"].shape[0]
1004-
# n_chemicals = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
1004+
# n_quantities = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
10051005
# model_config = get_model_config(surr_name, config)
1006-
# model = surrogate_class(device, n_chemicals, n_timesteps, model_config)
1006+
# model = surrogate_class(device, n_quantities, n_timesteps, model_config)
10071007
# model_identifier = f"{surr_name.lower()}_main"
10081008
# model.load(training_id, surr_name, model_identifier=model_identifier)
10091009
# MAE.append(model.MAE)

0 commit comments

Comments
 (0)