
Commit 1768e8d

cleanup, tests for model comparison
1 parent 6de4e76 commit 1768e8d

5 files changed

Lines changed: 173 additions & 97 deletions


codes/benchmark/bench_fcts.py

Lines changed: 0 additions & 31 deletions
@@ -1019,37 +1019,6 @@ def load_losses(model_identifier: str):
     )
 
 
-# def compare_MAE(metrics: dict, config: dict) -> None:
-#     """
-#     Compare the MAE of different surrogate models over the course of training.
-
-#     Args:
-#         metrics (dict): dictionary containing the benchmark metrics for each surrogate model.
-#         config (dict): Configuration dictionary.
-
-#     Returns:
-#         None
-#     """
-#     MAE = []
-#     labels = []
-#     train_durations = []
-#     device = config["devices"]
-#     device = device[0] if isinstance(device, list) else device
-
-#     for surr_name, _ in metrics.items():
-#         training_id = config["training_id"]
-#         surrogate_class = get_surrogate(surr_name)
-#         n_timesteps = metrics[surr_name]["timesteps"].shape[0]
-#         n_quantities = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
-#         model_config = get_model_config(surr_name, config)
-#         model = surrogate_class(device, n_quantities, n_timesteps, model_config)
-#         model_identifier = f"{surr_name.lower()}_main"
-#         model.load(training_id, surr_name, model_identifier=model_identifier)
-#         MAE.append(model.MAE)
-#         labels.append(surr_name)
-#         train_durations.append(model.train_duration)
-
-
 def compare_relative_errors(metrics: dict[str, dict], config: dict) -> None:
     """
     Compare the relative errors over time for different surrogate models.

codes/tune/__init__.py

Lines changed: 14 additions & 1 deletion
@@ -11,7 +11,15 @@
     maybe_set_runtime_threshold,
     training_run,
 )
-from .postgres_fcts import _make_db_url, initialize_optuna_database
+from .postgres_fcts import (
+    _make_db_url,
+    initialize_optuna_database,
+    _check_postgres_running_local,
+    _start_postgres_server_local,
+    _check_remote_reachable,
+    _initialize_postgres_local,
+    _initialize_postgres_remote,
+)
 from .tune_utils import (
     build_study_names,
     copy_config,
@@ -37,4 +45,9 @@
     "yes_no",
     "_make_db_url",
     "initialize_optuna_database",
+    "_check_postgres_running_local",
+    "_start_postgres_server_local",
+    "_check_remote_reachable",
+    "_initialize_postgres_local",
+    "_initialize_postgres_remote",
 ]
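With the widened re-export, the Postgres helpers can be imported from codes.tune directly rather than from the private codes.tune.postgres_fcts module. A minimal sketch of what this enables, for illustration only (assuming the package layout shown above):

    # illustrative only: these names are re-exported per the diff above
    from codes.tune import (
        _check_postgres_running_local,
        _initialize_postgres_local,
        _make_db_url,
    )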

codes/utils/data_utils.py

Lines changed: 0 additions & 64 deletions
@@ -223,70 +223,6 @@ def check_and_load_data(
     )
 
 
-def normalize_data_old(
-    train_data: np.ndarray,
-    test_data: np.ndarray | None = None,
-    val_data: np.ndarray | None = None,
-    mode: str = "standardise",
-) -> tuple:
-    """
-    Normalize the data based on the training data statistics.
-
-    Args:
-        train_data (np.ndarray): Training data array.
-        test_data (np.ndarray, optional): Test data array.
-        val_data (np.ndarray, optional): Validation data array.
-        mode (str): Normalization mode, either "minmax" or "standardise".
-
-    Returns:
-        tuple: Normalized training data, test data, and validation data.
-    """
-    if mode not in ["minmax", "standardise"]:
-        raise ValueError("Mode must be either 'minmax' or 'standardise'")
-
-    if mode == "minmax":
-        # Compute min and max on the training data
-        data_min = np.min(train_data)
-        data_max = np.max(train_data)
-
-        data_info = {"min": float(data_min), "max": float(data_max), "mode": mode}
-
-        # Normalize the training data
-        train_data_norm = 2 * (train_data - data_min) / (data_max - data_min) - 1
-
-        if test_data is not None:
-            test_data_norm = 2 * (test_data - data_min) / (data_max - data_min) - 1
-        else:
-            test_data_norm = None
-
-        if val_data is not None:
-            val_data_norm = 2 * (val_data - data_min) / (data_max - data_min) - 1
-        else:
-            val_data_norm = None
-
-    elif mode == "standardise":
-        # Compute mean and std on the training data
-        mean = np.mean(train_data)
-        std = np.std(train_data)
-
-        data_info = {"mean": float(mean), "std": float(std), "mode": mode}
-
-        # Standardize the training data
-        train_data_norm = (train_data - mean) / std
-
-        if test_data is not None:
-            test_data_norm = (test_data - mean) / std
-        else:
-            test_data_norm = None
-
-        if val_data is not None:
-            val_data_norm = (val_data - mean) / std
-        else:
-            val_data_norm = None
-
-    return data_info, train_data_norm, test_data_norm, val_data_norm
-
-
 def normalize_data(
     train_data: np.ndarray,
     test_data: np.ndarray | None = None,

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ line_length = 88
 
 [tool.coverage.run]
 source = ["codes"]
-omit = ["*/tests/*", "*/test_*", "*/bench_plots.py", "*/__init__.py"]
+omit = ["*/tests/*", "*/test_*", "*/bench_plots.py", "*/__init__.py", "*/evaluate_study.py", "*/evaluate_tuning.py"]
 
 [tool.coverage.report]
 exclude_lines = [
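The two added globs exclude the tuning evaluation scripts (evaluate_study.py, evaluate_tuning.py) from coverage measurement alongside the existing patterns. Under coverage.py's standard handling of omit in [tool.coverage.run], a run such as `coverage run -m pytest` followed by `coverage report` would not count lines from those files toward the totals.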

test/test_model_comparison.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+# test/test_compare_models.py
+import pytest
+from codes.benchmark import bench_fcts
+
+
+@pytest.fixture(autouse=True)
+def record_calls(monkeypatch):
+    """
+    Stub out all compare_* and plot_* functions so that calls
+    just record their names into a shared list, instead of doing any real work.
+    """
+    calls = []
+    names = [
+        "compare_relative_errors",
+        "compare_main_losses",
+        "compare_dynamic_accuracy",
+        "compare_inference_time",
+        "compare_interpolation",
+        "compare_extrapolation",
+        "compare_sparse",
+        "plot_all_generalization_errors",
+        "compare_batchsize",
+        "compare_UQ",
+        "tabular_comparison",
+    ]
+    for name in names:
+        monkeypatch.setattr(
+            bench_fcts,
+            name,
+            lambda *args, _n=name, **kw: calls.append(_n),
+        )
+    return calls
+
+
+def make_dummy_metrics():
+    """
+    Build a minimal metrics dict that contains the keys
+    your compare_models dispatcher will look up.
+    Values themselves are never inspected by our stubs.
+    """
+    return {
+        "M1": {
+            "accuracy": {"relative_errors": None},
+            "timesteps": None,
+            "n_params": 0,
+            # for each enabled branch add a dummy sub-dict:
+            "timing": {
+                "mean_inference_time_per_run": 1.0,
+                "std_inference_time_per_run": 0.1,
+            },
+            "gradients": {
+                "gradients": None,
+                "avg_correlation": 0.0,
+                "max_gradient": 0,
+                "max_error": 0,
+                "max_counts": 0,
+            },
+            "interpolation": {"intervals": [1], "model_errors": [0]},
+            "extrapolation": {"cutoffs": [1], "model_errors": [0]},
+            "sparse": {"n_train_samples": [10], "model_errors": [0]},
+            "batch_size": {"batch_sizes": [32], "model_errors": [0]},
+            "UQ": {
+                "pred_uncertainty": None,
+                "absolute_errors": None,
+                "relative_errors": None,
+                "axis_max": None,
+                "max_counts": None,
+                "correlation_metrics": None,
+                "weighted_diff": None,
+            },
+        }
+    }
+
+
+@pytest.mark.parametrize(
+    "flags, expected_sequence",
+    [
+        # all branches on
+        (
+            {
+                "losses": True,
+                "gradients": True,
+                "timing": True,
+                "interpolation": {"enabled": True},
+                "extrapolation": {"enabled": True},
+                "sparse": {"enabled": True},
+                "batch_scaling": {"enabled": True},
+                "uncertainty": {"enabled": True},
+            },
+            [
+                "compare_relative_errors",
+                "compare_main_losses",
+                "compare_dynamic_accuracy",
+                "compare_inference_time",
+                "compare_interpolation",
+                "compare_extrapolation",
+                "compare_sparse",
+                "plot_all_generalization_errors",  # only if int+ext+sparse all enabled
+                "compare_batchsize",
+                "compare_UQ",
+                "tabular_comparison",
+            ],
+        ),
+        # only the mandatory relative-errors + table
+        (
+            {
+                "losses": False,
+                "gradients": False,
+                "timing": False,
+                "interpolation": {"enabled": False},
+                "extrapolation": {"enabled": False},
+                "sparse": {"enabled": False},
+                "batch_scaling": {"enabled": False},
+                "uncertainty": {"enabled": False},
+            },
+            [
+                "compare_relative_errors",
+                "tabular_comparison",
+            ],
+        ),
+        # losses but nothing else
+        (
+            {
+                "losses": True,
+                "gradients": False,
+                "timing": False,
+                "interpolation": {"enabled": False},
+                "extrapolation": {"enabled": False},
+                "sparse": {"enabled": False},
+                "batch_scaling": {"enabled": False},
+                "uncertainty": {"enabled": False},
+            },
+            [
+                "compare_relative_errors",
+                "compare_main_losses",
+                "tabular_comparison",
+            ],
+        ),
+    ],
+)
+def test_compare_models_branching(record_calls, flags, expected_sequence):
+    cfg = {
+        "training_id": "test",
+        "devices": ["cpu"],  # for compare_main_losses
+        "losses": flags["losses"],
+        "gradients": flags["gradients"],
+        "timing": flags["timing"],
+        "interpolation": flags["interpolation"],
+        "extrapolation": flags["extrapolation"],
+        "sparse": flags["sparse"],
+        "batch_scaling": flags["batch_scaling"],
+        "uncertainty": flags["uncertainty"],
+    }
+    metrics = make_dummy_metrics()
+
+    bench_fcts.compare_models(metrics, cfg)
+
+    assert record_calls == expected_sequence
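One detail worth noting in the stub fixture above: the `_n=name` default argument binds the current loop value when each lambda is created, so every stub records its own function name rather than whichever name the loop finished on. A minimal, self-contained illustration of that Python closure behaviour (unrelated to this repository's code):

    # late binding: every closure reads the loop variable after the loop has ended
    fns = [lambda: name for name in ["a", "b"]]
    print([f() for f in fns])      # ['b', 'b']

    # default argument: the value is captured per iteration
    fns = [lambda _n=name: _n for name in ["a", "b"]]
    print([f() for f in fns])      # ['a', 'b']

Assuming the repository's usual pytest setup, the new cases can be run in isolation with `pytest test/test_model_comparison.py`.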
