Skip to content

Commit 396247d

Browse files
committed
test utils
1 parent 34ab7d2 commit 396247d

1 file changed

Lines changed: 211 additions & 0 deletions

File tree

test/test_utils.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import os
2+
import time
3+
import random
4+
import yaml
5+
import torch
6+
import numpy as np
7+
import pytest
8+
9+
from codes.utils import (
10+
read_yaml_config,
11+
time_execution,
12+
create_model_dir,
13+
get_progress_bar,
14+
load_and_save_config,
15+
set_random_seeds,
16+
nice_print,
17+
make_description,
18+
worker_init_fn,
19+
save_task_list,
20+
load_task_list,
21+
check_training_status,
22+
determine_batch_size,
23+
batch_factor_to_float,
24+
)
25+
26+
27+
def test_read_yaml_config_and_parse_for_none(tmp_path):
28+
data = {
29+
"a": "None",
30+
"b": 123,
31+
"c": {"d": "None", "e": "foo"},
32+
}
33+
f = tmp_path / "cfg.yaml"
34+
f.write_text(yaml.safe_dump(data))
35+
cfg = read_yaml_config(str(f))
36+
assert cfg["a"] is None
37+
assert cfg["b"] == 123
38+
assert cfg["c"]["d"] is None
39+
assert cfg["c"]["e"] == "foo"
40+
41+
42+
def test_time_execution_decorator():
43+
@time_execution
44+
def foo(x, y):
45+
time.sleep(0.01)
46+
return x + y
47+
48+
# before call
49+
assert foo.duration is None
50+
res = foo(2, 3)
51+
assert res == 5
52+
# after call
53+
assert isinstance(foo.duration, float)
54+
assert foo.duration >= 0.01
55+
56+
57+
def test_create_model_dir(tmp_path):
58+
base = str(tmp_path)
59+
out = create_model_dir(base_dir=base, subfolder="trained", unique_id="XYZ")
60+
assert os.path.isdir(out)
61+
# idempotent
62+
out2 = create_model_dir(base_dir=base, subfolder="trained", unique_id="XYZ")
63+
assert out2 == out
64+
65+
66+
def test_load_and_save_config(tmp_path, monkeypatch):
67+
# prepare a minimal yaml
68+
cfg = {"training_id": "T1", "foo": 42}
69+
cfg_path = tmp_path / "config.yaml"
70+
cfg_path.write_text(yaml.safe_dump(cfg))
71+
monkeypatch.chdir(tmp_path)
72+
loaded = load_and_save_config(str(cfg_path), save=True)
73+
assert loaded["training_id"] == "T1"
74+
dest = tmp_path / "trained" / "T1" / "config.yaml"
75+
assert dest.exists()
76+
# test save=False
77+
cfg2 = load_and_save_config(str(cfg_path), save=False)
78+
assert cfg2["foo"] == 42
79+
80+
81+
def test_set_random_seeds_reproducible():
82+
# reseeding twice with the same seed should give the same sequence
83+
set_random_seeds(123, device="cpu")
84+
r1 = random.random()
85+
n1 = np.random.rand()
86+
t1 = torch.rand(1).item()
87+
88+
set_random_seeds(123, device="cpu")
89+
r2 = random.random()
90+
n2 = np.random.rand()
91+
t2 = torch.rand(1).item()
92+
93+
assert pytest.approx(r1) == r2
94+
assert pytest.approx(n1) == n2
95+
assert pytest.approx(t1) == t2
96+
97+
98+
def test_nice_print(capsys):
99+
nice_print("Hello", width=20)
100+
out = capsys.readouterr().out
101+
lines = out.strip().splitlines()
102+
assert len(lines) == 3
103+
assert lines[0].startswith("-" * 20)
104+
assert "Hello" in lines[1]
105+
106+
107+
def test_make_description_padding():
108+
desc = make_description("mode", "cpu:0", "5", "SurrogateA")
109+
# should contain surrogate name left-justified
110+
assert desc.startswith("SurrogateA ")
111+
assert "(cpu:0)" in desc
112+
113+
114+
def test_get_progress_bar_and_worker_init_fn(monkeypatch):
115+
# first, check get_progress_bar
116+
bar = get_progress_bar(["t1", "t2", "t3", "t4"])
117+
assert bar.total == 4
118+
# its description should mention “Overall Progress”
119+
assert "Overall Progress" in bar.desc
120+
121+
# now stub out numpy.seed so negative seeds don't error
122+
called = {}
123+
124+
def fake_np_seed(s):
125+
called["seed"] = s
126+
127+
monkeypatch.setattr(np.random, "seed", fake_np_seed)
128+
129+
# set a known torch seed
130+
torch.manual_seed(42)
131+
# and invoke worker_init_fn
132+
worker_init_fn(0)
133+
134+
# ensure np.random.seed was called with an integer
135+
assert "seed" in called
136+
assert isinstance(called["seed"], int)
137+
138+
139+
def test_save_and_load_task_list(tmp_path):
140+
tasks = [{"x": 1}, {"y": 2}]
141+
fp = tmp_path / "tasks.json"
142+
save_task_list(tasks, str(fp))
143+
assert fp.exists()
144+
loaded = load_task_list(str(fp))
145+
assert loaded == tasks
146+
# non-existent
147+
missing = load_task_list(str(tmp_path / "nope.json"))
148+
assert missing == []
149+
150+
151+
def test_check_training_status_new(tmp_path, monkeypatch):
152+
cfg = {"training_id": tmp_path.name, "devices": ["cpu"]}
153+
monkeypatch.chdir(tmp_path)
154+
# no trained/<id> directory
155+
path, copy = check_training_status(cfg)
156+
assert path.endswith("trained/" + tmp_path.name + "/train_tasks.json")
157+
assert copy is True
158+
159+
160+
def test_check_training_status_existing_same(tmp_path, monkeypatch):
161+
# prepare a trained/<id>/config.yaml that *differs* so we hit the input() branch
162+
cfg = {"training_id": tmp_path.name, "devices": ["cpu"]}
163+
root = tmp_path / "trained" / tmp_path.name
164+
root.mkdir(parents=True)
165+
# saved config missing 'training_id' so triggers the "differs" logic
166+
sample = {"foo": 1, "devices": ["cpu"]}
167+
(root / "config.yaml").write_text(yaml.safe_dump(sample))
168+
monkeypatch.chdir(tmp_path)
169+
170+
# patch input() to say "yes, overwrite"
171+
monkeypatch.setenv("PYTEST_DISABLE_PLUGIN_AUTOLOAD", "1") # to quiet warnings
172+
monkeypatch.setenv("PYTHONIOENCODING", "utf-8")
173+
monkeypatch.setattr("builtins.input", lambda prompt="": "y")
174+
175+
path, copy = check_training_status(cfg)
176+
# since we answered "y", we should get copy=True
177+
assert copy is True
178+
assert path.endswith(f"trained/{tmp_path.name}/train_tasks.json")
179+
180+
181+
def test_determine_batch_size():
182+
cfg = {"batch_size": [10, 20], "surrogates": ["A", "B"]}
183+
# list mode
184+
assert determine_batch_size(cfg, 1, mode="", metric=0) == 20
185+
# global single
186+
cfg2 = {"batch_size": 5, "surrogates": ["X"]}
187+
assert determine_batch_size(cfg2, 0, mode="", metric=0) == 5
188+
# batchsize mode multiplies
189+
cfg3 = {"batch_size": 8, "surrogates": ["X"]}
190+
assert determine_batch_size(cfg3, 0, mode="batchsize", metric=3) == 24
191+
# mismatch length
192+
with pytest.raises(ValueError):
193+
determine_batch_size({"batch_size": [1], "surrogates": ["A", "B"]}, 0, "", 0)
194+
195+
196+
@pytest.mark.parametrize(
197+
"inp,exp",
198+
[
199+
(0.5, 0.5),
200+
(2, 2.0),
201+
("3.14", 3.14),
202+
("1/4", 0.25),
203+
],
204+
)
205+
def test_batch_factor_to_float_valid(inp, exp):
206+
assert batch_factor_to_float(inp) == pytest.approx(exp)
207+
208+
209+
def test_batch_factor_to_float_invalid():
210+
with pytest.raises(ValueError):
211+
batch_factor_to_float("not_a_number")

0 commit comments

Comments
 (0)