Skip to content

Commit 3b17a92

Browse files
committed
tuning tests
1 parent 562f6f4 commit 3b17a92

3 files changed

Lines changed: 336 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ line_length = 88
6565

6666
[tool.coverage.run]
6767
source = ["codes"]
68-
omit = ["*/tests/*", "*/test_*"]
68+
omit = ["*/tests/*", "*/test_*", "*/bench_plots.py", "*/__init__.py"]
6969

7070
[tool.coverage.report]
7171
exclude_lines = [

test/test_tuning_pipeline.py

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
import queue
2+
import math
3+
from datetime import datetime, timedelta
4+
5+
import pytest
6+
from optuna.trial import TrialState
7+
8+
from codes.tune.optuna_fcts import (
9+
make_optuna_params,
10+
maybe_set_runtime_threshold,
11+
create_objective,
12+
MODULE_REGISTRY,
13+
)
14+
15+
16+
class DummyTrial:
17+
def __init__(self):
18+
self.suggested = {}
19+
20+
def suggest_int(self, name, low, high, step=1):
21+
# always return low
22+
self.suggested[name] = low
23+
return low
24+
25+
def suggest_float(self, name, low, high, log=False):
26+
# always return high
27+
self.suggested[name] = high
28+
return high
29+
30+
def suggest_categorical(self, name, choices):
31+
# return first choice
32+
val = choices[0]
33+
self.suggested[name] = val
34+
return val
35+
36+
37+
class DummyModel:
38+
def __init__(
39+
self, device, n_quantities, n_timesteps, n_parameters, config=None, **kwargs
40+
):
41+
pass
42+
43+
def to(self, device):
44+
pass
45+
46+
def prepare_data(self, **kwargs):
47+
return "train_loader", "test_loader", None
48+
49+
def fit(self, **kwargs):
50+
pass
51+
52+
def predict(self, loader, leave_log=False):
53+
import torch
54+
55+
t = torch.zeros((2, 4, 1))
56+
return t, t
57+
58+
def save(self, **kwargs):
59+
pass
60+
61+
62+
@pytest.fixture
63+
def basic_params():
64+
return {
65+
"batch_size": {"type": "int", "low": 10, "high": 20, "step": 5},
66+
"learning_rate": {"type": "float", "low": 0.001, "high": 0.01, "log": True},
67+
"activation": {"choices": ["relu", "tanh", "identity"]},
68+
}
69+
70+
71+
@pytest.fixture
72+
def conditional_params(basic_params):
73+
p = basic_params.copy()
74+
p["scheduler"] = {"choices": ["poly", "cosine"]}
75+
p["poly_power"] = {"type": "int", "low": 1, "high": 3}
76+
p["eta_min"] = {"type": "float", "low": 0.0, "high": 0.1}
77+
return p
78+
79+
80+
def test_make_optuna_params_basic(basic_params):
81+
trial = DummyTrial()
82+
# only basic choices
83+
out = make_optuna_params(
84+
trial,
85+
{
86+
"batch_size": basic_params["batch_size"],
87+
"learning_rate": basic_params["learning_rate"],
88+
"activation": basic_params["activation"],
89+
},
90+
)
91+
# int param should equal low
92+
assert out["batch_size"] == basic_params["batch_size"]["low"]
93+
# float param should equal high
94+
assert math.isclose(out["learning_rate"], basic_params["learning_rate"]["high"])
95+
# activation returns first choice
96+
expected_cls = MODULE_REGISTRY[basic_params["activation"]["choices"][0]]
97+
assert isinstance(out["activation"], expected_cls)
98+
99+
100+
def test_make_optuna_params_conditional(conditional_params):
101+
trial = DummyTrial()
102+
# include scheduler to trigger poly branch
103+
params = conditional_params.copy()
104+
params["scheduler"]["choices"] = ["poly"]
105+
out = make_optuna_params(trial, params)
106+
# since scheduler chosen 'poly', poly_power must be suggested
107+
assert "poly_power" in out
108+
# cos branch
109+
trial2 = DummyTrial()
110+
p2 = conditional_params.copy()
111+
p2["scheduler"]["choices"] = ["cosine"]
112+
out2 = make_optuna_params(trial2, p2)
113+
assert "eta_min" in out2
114+
115+
116+
# --------- Tests for maybe_set_runtime_threshold ----------
117+
class FakeTrial:
118+
def __init__(self, num, state, start, complete=None):
119+
self.number = num
120+
self.state = state
121+
self.datetime_start = start
122+
self.datetime_complete = complete
123+
124+
125+
class FakeStudy:
126+
def __init__(self, trials):
127+
self._trials = trials
128+
self.user_attrs = {}
129+
130+
def get_trials(self, deepcopy=False):
131+
return self._trials
132+
133+
def set_user_attr(self, key, val):
134+
self.user_attrs[key] = val
135+
136+
137+
def test_maybe_set_runtime_threshold_not_enough():
138+
# only 1 complete trial, warmup_target=2
139+
t1 = FakeTrial(
140+
0,
141+
TrialState.COMPLETE,
142+
datetime.utcnow() - timedelta(seconds=5),
143+
datetime.utcnow(),
144+
)
145+
study = FakeStudy([t1])
146+
maybe_set_runtime_threshold(study, warmup_target=2)
147+
assert "runtime_threshold" not in study.user_attrs
148+
149+
150+
def test_maybe_set_runtime_threshold_enough():
151+
now = datetime.utcnow()
152+
trials = []
153+
for i in range(3):
154+
trials.append(
155+
FakeTrial(
156+
i,
157+
TrialState.COMPLETE,
158+
now - timedelta(seconds=10 + i),
159+
now - timedelta(seconds=i),
160+
)
161+
)
162+
study = FakeStudy(trials)
163+
maybe_set_runtime_threshold(study, warmup_target=3)
164+
# after enough trials, attrs should be set
165+
assert "runtime_threshold" in study.user_attrs
166+
assert "warmup_mean" in study.user_attrs
167+
assert study.user_attrs["warmup_target"] == 3
168+
169+
170+
# --------- Tests for create_objective ----------
171+
172+
173+
def test_training_run_single_objective(monkeypatch, tmp_path):
174+
# Prepare dummy config
175+
config = {
176+
"dataset": {
177+
"name": "ds",
178+
"log10_transform": False,
179+
"normalise": "none",
180+
"tolerance": 1e-3,
181+
},
182+
"seed": 123,
183+
"surrogate": {"name": "Surr"},
184+
"optuna_params": {},
185+
"batch_size": 16,
186+
"epochs": 5,
187+
"target_percentile": 0.5,
188+
"multi_objective": False,
189+
}
190+
# Fake download
191+
monkeypatch.setattr("codes.tune.optuna_fcts.download_data", lambda *args, **k: None)
192+
# Fake data loaders
193+
import numpy as np
194+
195+
dummy_data = np.zeros((2, 4, 1))
196+
dummy_params = np.zeros((2, 3))
197+
dummy_timesteps = np.arange(4)
198+
dummy_info = {}
199+
monkeypatch.setattr(
200+
"codes.tune.optuna_fcts.check_and_load_data",
201+
lambda *args, **kw: (
202+
(dummy_data, dummy_data, None),
203+
(dummy_params, dummy_params, None),
204+
dummy_timesteps,
205+
None,
206+
dummy_info,
207+
None,
208+
),
209+
)
210+
monkeypatch.setattr(
211+
"codes.tune.optuna_fcts.get_data_subset",
212+
lambda *args, **kw: (
213+
(dummy_data, dummy_data),
214+
(dummy_params, dummy_params),
215+
dummy_timesteps,
216+
),
217+
)
218+
monkeypatch.setattr("codes.tune.optuna_fcts.set_random_seeds", lambda *a, **k: None)
219+
220+
monkeypatch.setattr("codes.tune.optuna_fcts.get_surrogate", lambda name: DummyModel)
221+
monkeypatch.setattr("codes.tune.optuna_fcts.get_model_config", lambda name, cfg: {})
222+
monkeypatch.setattr(
223+
"codes.tune.optuna_fcts.make_optuna_params", lambda trial, params: {}
224+
)
225+
# patch quantile
226+
import torch
227+
228+
monkeypatch.setattr(torch, "quantile", lambda x, q: torch.tensor(7.0))
229+
# Run
230+
trial = type("T", (object,), {"number": 0})()
231+
val = __import__("codes.tune.optuna_fcts", fromlist=["training_run"]).training_run(
232+
trial, "cpu", 0, config, "study1"
233+
)
234+
assert isinstance(val, float) and val == 7.0
235+
236+
237+
def test_training_run_multi_objective(monkeypatch, tmp_path):
238+
config = {
239+
"dataset": {
240+
"name": "ds",
241+
"log10_transform": False,
242+
"normalise": "none",
243+
"tolerance": 1e-3,
244+
},
245+
"seed": 456,
246+
"surrogate": {"name": "Surr"},
247+
"optuna_params": {},
248+
"batch_size": 8,
249+
"epochs": 5,
250+
"target_percentile": 0.5,
251+
"multi_objective": True,
252+
}
253+
# stub all as above
254+
monkeypatch.setattr("codes.tune.optuna_fcts.download_data", lambda *args, **k: None)
255+
import numpy as np
256+
257+
dummy_data = np.ones((3, 5, 1))
258+
dummy_params = np.ones((3, 2))
259+
dummy_timesteps = np.arange(5)
260+
dummy_info = {}
261+
monkeypatch.setattr(
262+
"codes.tune.optuna_fcts.check_and_load_data",
263+
lambda *args, **kw: (
264+
(dummy_data, dummy_data, None),
265+
(dummy_params, dummy_params, None),
266+
dummy_timesteps,
267+
None,
268+
dummy_info,
269+
None,
270+
),
271+
)
272+
monkeypatch.setattr(
273+
"codes.tune.optuna_fcts.get_data_subset",
274+
lambda *args, **kw: (
275+
(dummy_data, dummy_data),
276+
(dummy_params, dummy_params),
277+
dummy_timesteps,
278+
),
279+
)
280+
monkeypatch.setattr("codes.tune.optuna_fcts.set_random_seeds", lambda *a, **k: None)
281+
282+
class DummyModel2(DummyModel):
283+
def predict(self, loader, leave_log=False):
284+
import torch
285+
286+
return torch.zeros((3, 5, 1)), torch.ones((3, 5, 1))
287+
288+
monkeypatch.setattr(
289+
"codes.tune.optuna_fcts.get_surrogate", lambda name: DummyModel2
290+
)
291+
monkeypatch.setattr("codes.tune.optuna_fcts.get_model_config", lambda name, cfg: {})
292+
monkeypatch.setattr(
293+
"codes.tune.optuna_fcts.make_optuna_params", lambda trial, params: {}
294+
)
295+
monkeypatch.setattr(
296+
"codes.tune.optuna_fcts.measure_inference_time", lambda m, l: [1.0, 2.0, 3.0]
297+
)
298+
import torch
299+
300+
monkeypatch.setattr(torch, "quantile", lambda x, q: torch.tensor(5.0))
301+
302+
trial = type("T", (object,), {"number": 2})()
303+
val = __import__("codes.tune.optuna_fcts", fromlist=["training_run"]).training_run(
304+
trial, "cpu", 1, config, "study_2"
305+
)
306+
# expect (loss, mean_inference)
307+
assert isinstance(val, tuple) and val[0] == 5.0 and val[1] == pytest.approx(2.0)
308+
309+
310+
def test_create_objective_simple(monkeypatch):
311+
# stub training_run
312+
called = {}
313+
314+
def fake_run(trial, device, slot, config, name):
315+
called["args"] = (device, slot, config, name)
316+
return 42.0
317+
318+
monkeypatch.setattr("codes.tune.optuna_fcts.training_run", fake_run)
319+
320+
device_queue = queue.Queue()
321+
device_queue.put(("cpu", 0))
322+
config = {"dataset": {"name": "ds"}}
323+
obj = create_objective(config, "study1", device_queue)
324+
325+
class DummyT:
326+
number = 5
327+
328+
trial = DummyT()
329+
result = obj(trial)
330+
# training_run returned 42.0
331+
assert result == 42.0
332+
# device put back
333+
assert not device_queue.empty()
334+
dev, slot = device_queue.get()
335+
assert dev == "cpu" and slot == 0

0 commit comments

Comments
 (0)