
Commit 562f6f4

Tests for surrogates, datasets, training, eval
1 parent 54a65ff

11 files changed: 678 additions & 207 deletions

codes/train/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
     train_and_save_model,
     create_task_list_for_surrogate,
     worker,
+    DummyLock,
 )

 __all__ = [
@@ -12,4 +13,5 @@
     "train_and_save_model",
     "create_task_list_for_surrogate",
     "worker",
+    "DummyLock",
 ]
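This hunk only re-exports DummyLock from codes.train; its definition is not part of the diff shown here. As a rough sketch of what such a helper usually is (an assumption about the common pattern, not the repository's actual implementation), it acts as a no-op stand-in for multiprocessing.Lock so single-process code paths need no special-casing:

# Hypothetical sketch; the real DummyLock in codes/train may differ.
class DummyLock:
    """No-op drop-in replacement for multiprocessing.Lock."""

    def acquire(self, *args, **kwargs):
        # Always "succeeds" immediately, like an uncontended lock.
        return True

    def release(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Do not suppress exceptions raised inside the with-block.
        return False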

config_full.yaml

Lines changed: 13 additions & 13 deletions
@@ -1,9 +1,9 @@
 # Global settings for the benchmark
 training_id: "optimizer_test"
-surrogates: ["MultiONet", "FullyConnected", "LatentNeuralODE", "LatentPoly"]
-batch_size: [65536, 65536, 512, 512]
+surrogates: ["MultiONet", "FullyConnected", "LatentNeuralODE", "LatentPoly"]
+batch_size: [65536, 65536, 512, 512]
 epochs: [200, 200, 110, 200] # [20000, 7500, 20000, 15000]
-dataset:
+dataset:
   name: "cloud"
   log10_transform: True
   log10_transform_params: False
@@ -13,28 +13,28 @@ dataset:
   tolerance: 1e-25
   subset_factor: 1
   log_timesteps: True
-devices: ["cuda:0", "cuda:1", "cuda:2", "cuda:3"] # ["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "cuda:8"]
+devices: ["cuda:0", "cuda:1", "cuda:2", "cuda:3"] # ["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "cuda:8"]
 seed: 42
 verbose: False
 relative_error_threshold: 1e-10
 checkpointing: True

 # Models to train
-interpolation:
-  enabled: False
+interpolation:
+  enabled: True
   intervals: [2, 3, 4, 5, 6, 7, 8, 9, 10]
 extrapolation:
-  enabled: False
+  enabled: True
   cutoffs: [50, 60, 70, 80, 90]
-sparse:
-  enabled: False
+sparse:
+  enabled: True
   factors: [2, 4, 8, 16, 32]
 batch_scaling:
-  enabled: False
+  enabled: True
   sizes: [1/16, 1/8, 1/4, 1/2]
-uncertainty:
-  enabled: False
-  ensemble_size: 5 # Number of models for deep ensemble
+uncertainty:
+  enabled: True
+  ensemble_size: 5 # Number of models for deep ensemble

 # Evaluations during benchmark
 losses: True
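As a quick way to confirm how these toggles parse, the config can be loaded with PyYAML and inspected. This is a minimal sketch; the yaml.safe_load route and the relative file path are assumptions, not necessarily how the benchmark itself reads the file:

# Sketch: list which study modalities are enabled after this change (assumes PyYAML is installed).
import yaml

with open("config_full.yaml") as f:
    config = yaml.safe_load(f)

modalities = ["interpolation", "extrapolation", "sparse", "batch_scaling", "uncertainty"]
print("Enabled:", [m for m in modalities if config.get(m, {}).get("enabled", False)])

# YAML has no fraction literal, so batch_scaling.sizes entries like 1/16 are
# parsed as strings ("1/16") and presumably converted to numbers downstream.
print("batch_scaling sizes:", config["batch_scaling"]["sizes"])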

test/test_bench_modalities.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import pytest
import numpy as np
import torch
from unittest.mock import patch

from codes.benchmark.bench_fcts import (
    evaluate_interpolation,
    evaluate_extrapolation,
    evaluate_sparse,
    evaluate_batchsize,
    evaluate_UQ,
)


# Dummy model that records which checkpoints it is asked to load
class DummyModel:
    def __init__(self, device, n_quantities, n_timesteps, n_parameters, config):
        self._loads = []

    def load(self, training_id, surr_name, model_identifier):
        self._loads.append(model_identifier)

    def predict(self, data_loader):
        # Random predictions and targets of shape (batch=2, timesteps=4, quantities=1).
        preds = torch.rand(2, 4, 1)
        targets = torch.rand(2, 4, 1)
        return preds, targets


# Two standalone fakes: one for the heatmap (returns a tuple), one for all other plots (returns None)
def _fake_heatmap(*args, **kwargs):
    return ([], [])


def _fake_noop(*args, **kwargs):
    return None


@pytest.fixture(autouse=True)
def patch_plots():
    import codes.benchmark.bench_fcts as bf

    fake_impl = {}
    for name in dir(bf):
        if not name.startswith("plot_"):
            continue
        if name == "plot_error_correlation_heatmap":
            fake_impl[name] = _fake_heatmap
        else:
            fake_impl[name] = _fake_noop

    with patch.multiple("codes.benchmark.bench_fcts", **fake_impl):
        yield


@pytest.mark.parametrize(
    "raw_vals, cfg_key, func, main_bs, expected_nums",
    [
        ([2, 3, 5], "interpolation", evaluate_interpolation, None, [1, 2, 3, 5]),
        ([1, 2, 4], "extrapolation", evaluate_extrapolation, None, [1, 2, 4]),
        ([2, 4, 8], "sparse", evaluate_sparse, None, [1, 2, 4, 8]),
        ([0.5, 2], "batch_scaling", evaluate_batchsize, 8, [4, 8, 16]),
        (3, "uncertainty", evaluate_UQ, None, [0, 1, 2]),
    ],
)
def test_modality_variations(raw_vals, cfg_key, func, main_bs, expected_nums):
    surr = "TestSurr"
    cfg = {"training_id": "TID", "surrogates": [surr]}
    if cfg_key == "uncertainty":
        cfg["uncertainty"] = {"enabled": True, "ensemble_size": raw_vals}
    else:
        cfg[cfg_key] = {"enabled": True}
        subkey = {
            "interpolation": "intervals",
            "extrapolation": "cutoffs",
            "sparse": "factors",
            "batch_scaling": "sizes",
        }[cfg_key]
        cfg[cfg_key][subkey] = raw_vals
    if cfg_key == "batch_scaling":
        cfg["batch_size"] = [main_bs]

    timesteps = np.arange(4)
    loader = object()
    labels = ["q"] if func is evaluate_interpolation else None

    model = DummyModel(None, 1, len(timesteps), 0, {})

    # invoke
    if func is evaluate_interpolation:
        metrics = func(model, surr, loader, timesteps, cfg, labels)
    elif func is evaluate_extrapolation:
        metrics = func(model, surr, loader, timesteps, cfg, labels)
    elif func is evaluate_sparse:
        metrics = func(model, surr, loader, timesteps, n_train_samples=10, conf=cfg)
    elif func is evaluate_batchsize:
        metrics = func(model, surr, loader, timesteps, cfg)
    else:
        metrics = func(model, surr, loader, timesteps, cfg, labels=None)

    lower = surr.lower()
    # build expected identifiers
    ids = []
    if cfg_key == "interpolation":
        for i in expected_nums:
            ids.append(f"{lower}_main" if i == 1 else f"{lower}_interpolation_{i}")
    elif cfg_key == "extrapolation":
        max_c = len(timesteps)
        for c in expected_nums:
            ids.append(f"{lower}_main" if c == max_c else f"{lower}_extrapolation_{c}")
    elif cfg_key == "sparse":
        for f in expected_nums:
            ids.append(f"{lower}_main" if f == 1 else f"{lower}_sparse_{f}")
    elif cfg_key == "batch_scaling":
        for bs in expected_nums:
            ids.append(f"{lower}_main" if bs == main_bs else f"{lower}_batchsize_{bs}")
    else:  # uncertainty
        for idx in expected_nums:
            ids.append(f"{lower}_main" if idx == 0 else f"{lower}_UQ_{idx}")

    assert model._loads == ids

    prefix = {
        "interpolation": "interval",
        "extrapolation": "cutoff",
        "sparse": "factor",
        "batch_scaling": "batch_size",
        "uncertainty": None,
    }[cfg_key]

    if cfg_key != "uncertainty":
        for num in expected_nums:
            assert f"{prefix} {num}" in metrics
    else:
        assert "average_uncertainty" in metrics
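To exercise just this new module, an invocation along the following lines should work (a sketch assuming pytest, numpy, and torch are installed and the repository root is on the Python path):

# Equivalent to running "pytest test/test_bench_modalities.py -v" from the shell.
import pytest

raise SystemExit(pytest.main(["-v", "test/test_bench_modalities.py"]))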

test/test_data.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

test/test_datasets.py

Lines changed: 0 additions & 5 deletions
@@ -1,8 +1,3 @@
-"""
-Comprehensive unit tests for dataset functionality, including data loading,
-downloading, and validation of available datasets.
-"""
-
 import os
 import tempfile
