Commit 25e9494

Merge pull request #46 from robin-janssen/improve-unit-tests
Further improvements for unit tests
2 parents 0a6dfc4 + 396247d commit 25e9494

7 files changed: 828 additions & 33 deletions

File tree

AGENT.md

Lines changed: 0 additions & 26 deletions
This file was deleted.

README.md

Lines changed: 2 additions & 3 deletions
@@ -1,15 +1,14 @@
 # CODES Benchmark
 
-[![codecov](https://codecov.io/github/robin-janssen/CODES-Benchmark/branch/develop/graph/badge.svg?token=TNF9ISCAJK)](https://codecov.io/github/robin-janssen/CODES-Benchmark)
+[![codecov](https://codecov.io/github/robin-janssen/CODES-Benchmark/graph/badge.svg?token=TNF9ISCAJK)](https://codecov.io/github/robin-janssen/CODES-Benchmark)
 ![Static Badge](https://img.shields.io/badge/license-GPLv3-blue)
 ![Static Badge](https://img.shields.io/badge/NeurIPS-2024-green)
 
-
 🎉 CODES was accepted to the ML4PS workshop @ NeurIPS2024 🎉
 
 ## Benchmarking Coupled ODE Surrogates
 
-CODES is a benchmark for coupled ODE surrogate models.
+CODES is a benchmark for coupled ODE surrogate models.
 
 <picture>
   <!-- Dark mode SVG -->

codes/benchmark/bench_utils.py

Lines changed: 4 additions & 4 deletions
@@ -78,15 +78,15 @@ def check_benchmark(conf: dict) -> None:
 
     training_conf = read_yaml_config(yaml_file)
 
-    # 1. Check Surrogates
+    # Check Surrogates
     training_surrogates = set(training_conf.get("surrogates", []))
     benchmark_surrogates = set(conf.get("surrogates", []))
     if not benchmark_surrogates.issubset(training_surrogates):
         raise ValueError(
             "Benchmark configuration includes surrogates that were not in the training configuration."
         )
 
-    # 2. Check Batch Size
+    # Check Batch Size
     if "batch_size" in conf:
         training_batch_size = training_conf.get("batch_size", [])
         benchmark_batch_size = conf.get("batch_size", [])
@@ -107,7 +107,7 @@ def check_benchmark(conf: dict) -> None:
         print("Exiting...")
         exit()
 
-    # 3. Check Dataset Settings
+    # Check Dataset Settings
     training_dataset = training_conf.get("dataset", {})
     benchmark_dataset = conf.get("dataset", {})
 
@@ -127,7 +127,7 @@ def check_benchmark(conf: dict) -> None:
             f"Additional dataset setting '{key}' found in benchmark configuration that is not present in training configuration."
         )
 
-    # 4. Check Modalities (Interpolation, Extrapolation, Sparse, Batch Scaling, Uncertainty)
+    # Check Modalities (Interpolation, Extrapolation, Sparse, Batch Scaling, Uncertainty)
     modalities = [
         "interpolation",
         "extrapolation",
test/test_bench_compare.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@ (new file)

import numpy as np
import pytest
import os
import codes.benchmark.bench_fcts as bf


# Helpers to build dummy metrics
def make_metrics_for_main(surr_names, n_timesteps=4, n_quantities=3):
    # for compare_main_losses
    metrics = {}
    for i, name in enumerate(surr_names):
        metrics[name] = {
            "timesteps": np.zeros(n_timesteps),
            "accuracy": {"absolute_errors": np.zeros((1, n_timesteps, n_quantities))},
            "n_params": 123 + i,
        }
    return metrics


def make_metrics_for_relative(surr_names, timesteps):
    # build metrics with a relative_errors array
    metrics = {}
    for name in surr_names:
        rel = np.arange(len(timesteps) * 1.0 * 1.0).reshape(1, len(timesteps), 1)
        metrics[name] = {
            "accuracy": {"relative_errors": rel},
            "timesteps": np.array(timesteps),
        }
    return metrics


def make_metrics_for_timing(surr_names):
    metrics = {}
    for name in surr_names:
        metrics[name] = {
            "timing": {
                "mean_inference_time_per_run": 1.23,
                "std_inference_time_per_run": 0.45,
            }
        }
    return metrics


def make_metrics_for_dynamic(surr_names, n_timesteps=4, n_quantities=2):
    metrics = {}
    for name in surr_names:
        metrics[name] = {
            "accuracy": {"absolute_errors": np.zeros((1, n_timesteps, n_quantities))},
            "gradients": {
                "gradients": np.ones((1, n_timesteps, n_quantities)),
                "avg_correlation": 0.5,
                "max_gradient": 0.6,
                "max_error": 0.7,
                "max_counts": 8,
            },
        }
    return metrics


def make_metrics_for_generalization(surr_names):
    base = {}
    for name in surr_names:
        base[name] = {
            "interpolation": {"intervals": [1, 2], "model_errors": [0.1, 0.2]},
            "extrapolation": {"cutoffs": [1, 3], "model_errors": [0.3, 0.4]},
            "sparse": {"n_train_samples": [5, 10], "model_errors": [0.5, 0.6]},
            "batch_size": {"batch_sizes": [16, 32], "model_errors": [0.7, 0.8]},
            "UQ": {
                "pred_uncertainty": np.ones((1, 2, 1)) * 0.9,
                "absolute_errors": np.zeros((1, 2, 1)),
                "relative_errors": np.zeros((1, 2, 1)),
                "axis_max": 2,
                "max_counts": 3,
                "correlation_metrics": 0.2,
                "weighted_diff": np.zeros((1, 2, 1)),
            },
        }
    return base


@pytest.fixture(autouse=True)
def stub_plots_and_io(monkeypatch):
    calls = []
    for fn in [
        # only stub out any pure-plot routines, but leave the CSV writers alone
        "inference_time_bar_plot",
        "plot_comparative_dynamic_correlation_heatmaps",
        "plot_generalization_error_comparison",
        "plot_uncertainty_over_time_comparison",
        "plot_comparative_error_correlation_heatmaps",
        "plot_error_distribution_comparative",
        "plot_uncertainty_confidence",
        "plot_loss_comparison",
        "plot_loss_comparison_equal",
        "plot_loss_comparison_train_duration",
        "plot_relative_errors",
        "plot_error_distribution_comparative",
    ]:
        monkeypatch.setattr(bf, fn, lambda *a, _n=fn, **k: calls.append((_n, a, k)))
    return calls


@pytest.fixture
def cfg():
    return {
        "training_id": "TID",
        "devices": ["cpu"],
        "losses": True,
        "gradients": True,
        "timing": True,
        "interpolation": {"enabled": True},
        "extrapolation": {"enabled": True},
        "sparse": {"enabled": True},
        "batch_scaling": {"enabled": True},
        "uncertainty": {"enabled": True},
        "epochs": [5],
        "surrogates": ["M1"],
        "relative_error_threshold": 0.0,
    }


def test_compare_main_losses(stub_plots_and_io, cfg, monkeypatch):
    metrics = make_metrics_for_main(["M1"])

    # stub get_surrogate -> our fake model class
    class Fake:
        def __init__(self, device, n_quantities, n_timesteps, n_parameters, config):
            self.train_loss = 0.11
            self.test_loss = 0.22
            self.train_duration = 3.14

        def load(self, *args, **kw):
            pass

    monkeypatch.setattr(bf, "get_surrogate", lambda n: Fake)
    monkeypatch.setattr(bf, "get_model_config", lambda *a, **k: {})
    bf.compare_main_losses(metrics, cfg)
    names = [c[0] for c in stub_plots_and_io]
    # expect the three plotting calls, in order
    assert names[:3] == [
        "plot_loss_comparison",
        "plot_loss_comparison_equal",
        "plot_loss_comparison_train_duration",
    ]


def test_compare_relative_errors(stub_plots_and_io, cfg):
    timesteps = [0.0, 1.0, 2.0]
    metrics = make_metrics_for_relative(["M1"], timesteps)
    bf.compare_relative_errors(metrics, cfg)
    # mean and median come from np.mean/median over rel errors
    mean_err = np.mean(metrics["M1"]["accuracy"]["relative_errors"], axis=(0, 2))
    median_err = np.median(metrics["M1"]["accuracy"]["relative_errors"], axis=(0, 2))
    # first call to plot_relative_errors
    _n, args, kw = stub_plots_and_io[0]
    assert _n == "plot_relative_errors"
    # args = ( mean_dict, median_dict, timesteps, cfg )
    assert pytest.approx(list(args[0].values())[0]) == mean_err
    assert pytest.approx(list(args[1].values())[0]) == median_err
    assert np.all(args[2] == timesteps)
    # second call
    assert stub_plots_and_io[1][0] == "plot_error_distribution_comparative"


def test_compare_inference_time(stub_plots_and_io, cfg):
    metrics = make_metrics_for_timing(["M1"])
    bf.compare_inference_time(metrics, cfg)

    name, args, kw = stub_plots_and_io[0]
    assert name == "inference_time_bar_plot"

    # now args has 5 elements, show_title is in kw
    surrogates, means, stds, conf, save_flag = args
    assert surrogates == ["M1"]
    assert means == [1.23]
    assert stds == [0.45]
    assert conf is cfg
    assert save_flag is True
    assert kw.get("show_title", False) is True


def test_compare_dynamic_accuracy(stub_plots_and_io, cfg):
    m = make_metrics_for_dynamic(["M1"])
    bf.compare_dynamic_accuracy(m, cfg)

    name, args, kw = stub_plots_and_io[0]
    assert name == "plot_comparative_dynamic_correlation_heatmaps"

    # now args has 7 elements, show_title is in kw
    grads, abs_errs, corrs, max_grads, max_errs, max_counts, conf = args
    assert list(grads.keys()) == ["M1"]
    assert corrs["M1"] == 0.5
    assert max_counts["M1"] == 8
    assert conf is cfg
    assert kw.get("show_title", False) is True


def test_compare_UQ_and_confidence(stub_plots_and_io, cfg, monkeypatch):
    base = make_metrics_for_generalization(["M1"])
    # ADD a dummy timesteps array
    base["M1"]["timesteps"] = np.array([0.0, 1.0])

    # stub out plot_uncertainty_confidence to return known scores
    monkeypatch.setattr(bf, "plot_uncertainty_confidence", lambda *a, **k: {"M1": 0.42})

    bf.compare_UQ(base, cfg)

    # after compare_UQ, confidence_scores should exist in metrics
    assert base["M1"]["UQ"]["confidence_scores"] == 0.42


def test_tabular_comparison_creates_files(tmp_path, stub_plots_and_io, monkeypatch):
    metrics = {
        "M1": {
            "accuracy": {
                "mean_squared_error": 0.1,
                "mean_absolute_error": 0.2,
                "mean_relative_error": 0.3,
                "main_model_epochs": 4,
                "main_model_training_time": 7.0,
            }
        },
        "M2": {
            "accuracy": {
                "mean_squared_error": 0.01,
                "mean_absolute_error": 0.02,
                "mean_relative_error": 0.03,
                "main_model_epochs": 5,
                "main_model_training_time": 3.0,
            }
        },
    }

    cfg = {
        "training_id": tmp_path.name,
        "timing": False,
        "gradients": False,
        "compute": False,
        "uncertainty": {"enabled": False},
        "interpolation": {"enabled": False},
        "extrapolation": {"enabled": False},
        "sparse": {"enabled": False},
        "batch_scaling": {"enabled": False},
        "verbose": False,
    }

    # run inside tmp_path
    monkeypatch.chdir(tmp_path)
    # manually create results/<study_id>
    os.makedirs(tmp_path / "results" / cfg["training_id"], exist_ok=True)

    stub_plots_and_io.clear()
    bf.tabular_comparison(metrics, cfg)

    # ensure the text and CSV files on disk
    assert (tmp_path / "results" / cfg["training_id"] / "metrics_table.txt").exists()
    assert (tmp_path / "results" / cfg["training_id"] / "metrics_table.csv").exists()
    assert (tmp_path / "results" / cfg["training_id"] / "all_metrics.csv").exists()
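
For reference, the new module can be exercised in isolation with pytest. A minimal sketch, assuming pytest and the codes package are installed and the script is run from the repository root:

# Run only the new comparison tests; pytest.main returns an exit code.
import sys

import pytest

sys.exit(pytest.main(["test/test_bench_compare.py", "-q"]))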
