|
| 1 | +import os |
| 2 | +import time |
| 3 | +import random |
| 4 | +import yaml |
| 5 | +import torch |
| 6 | +import numpy as np |
| 7 | +import pytest |
| 8 | + |
| 9 | +from codes.utils import ( |
| 10 | + read_yaml_config, |
| 11 | + time_execution, |
| 12 | + create_model_dir, |
| 13 | + get_progress_bar, |
| 14 | + load_and_save_config, |
| 15 | + set_random_seeds, |
| 16 | + nice_print, |
| 17 | + make_description, |
| 18 | + worker_init_fn, |
| 19 | + save_task_list, |
| 20 | + load_task_list, |
| 21 | + check_training_status, |
| 22 | + determine_batch_size, |
| 23 | + batch_factor_to_float, |
| 24 | +) |
| 25 | + |
| 26 | + |
def test_read_yaml_config_and_parse_for_none(tmp_path):
    """The literal string "None" is parsed to Python None, including nested keys."""
    raw = {"a": "None", "b": 123, "c": {"d": "None", "e": "foo"}}
    cfg_file = tmp_path / "cfg.yaml"
    cfg_file.write_text(yaml.safe_dump(raw))

    parsed = read_yaml_config(str(cfg_file))

    # "None" strings become real None, other values pass through unchanged
    assert parsed["a"] is None
    assert parsed["c"]["d"] is None
    assert parsed["b"] == 123
    assert parsed["c"]["e"] == "foo"
| 40 | + |
| 41 | + |
def test_time_execution_decorator():
    """@time_execution records the wrapped call's runtime on the function object."""

    @time_execution
    def add(a, b):
        time.sleep(0.01)
        return a + b

    # no measurement exists until the function has run once
    assert add.duration is None
    assert add(2, 3) == 5
    # after the call, duration is a float covering at least the sleep time
    assert isinstance(add.duration, float)
    assert add.duration >= 0.01
| 55 | + |
| 56 | + |
def test_create_model_dir(tmp_path):
    """create_model_dir creates the directory and is idempotent on repeat calls."""
    root = str(tmp_path)

    first = create_model_dir(base_dir=root, subfolder="trained", unique_id="XYZ")
    assert os.path.isdir(first)

    # calling again with identical arguments yields the same path, no error
    second = create_model_dir(base_dir=root, subfolder="trained", unique_id="XYZ")
    assert second == first
| 64 | + |
| 65 | + |
def test_load_and_save_config(tmp_path, monkeypatch):
    """load_and_save_config returns the parsed config and mirrors it when save=True."""
    config = {"training_id": "T1", "foo": 42}
    src = tmp_path / "config.yaml"
    src.write_text(yaml.safe_dump(config))
    monkeypatch.chdir(tmp_path)

    # save=True: config is returned and copied to trained/<training_id>/
    loaded = load_and_save_config(str(src), save=True)
    assert loaded["training_id"] == "T1"
    assert (tmp_path / "trained" / "T1" / "config.yaml").exists()

    # save=False: config still loads correctly, nothing extra asserted on disk
    assert load_and_save_config(str(src), save=False)["foo"] == 42
| 79 | + |
| 80 | + |
def test_set_random_seeds_reproducible():
    """Reseeding with the same seed replays identical random/numpy/torch draws."""

    def draw_once():
        # seed all three RNGs, then take one sample from each
        set_random_seeds(123, device="cpu")
        return random.random(), np.random.rand(), torch.rand(1).item()

    first = draw_once()
    second = draw_once()

    for a, b in zip(first, second):
        assert pytest.approx(a) == b
| 96 | + |
| 97 | + |
def test_nice_print(capsys):
    """nice_print frames the message between horizontal rules of the given width."""
    nice_print("Hello", width=20)

    printed = capsys.readouterr().out.strip().splitlines()
    # three lines: rule, message, rule
    assert len(printed) == 3
    assert printed[0].startswith("-" * 20)
    assert "Hello" in printed[1]
| 105 | + |
| 106 | + |
def test_make_description_padding():
    """The surrogate name is left-justified and the device appears in parentheses."""
    description = make_description("mode", "cpu:0", "5", "SurrogateA")

    # left-justification pads the surrogate name with trailing spaces
    assert description.startswith("SurrogateA ")
    assert "(cpu:0)" in description
| 112 | + |
| 113 | + |
def test_get_progress_bar_and_worker_init_fn(monkeypatch):
    """get_progress_bar sizes itself from the task list; worker_init_fn seeds numpy.

    Fix: the progress bar is now explicitly closed — the original leaked the
    tqdm handle, which can leave a dangling instance and stray terminal output.
    """
    bar = get_progress_bar(["t1", "t2", "t3", "t4"])
    try:
        # total matches the number of tasks; description labels the overall bar
        assert bar.total == 4
        assert "Overall Progress" in bar.desc
    finally:
        bar.close()  # release the tqdm bar even if an assertion above fails

    # Stub numpy's seeder so out-of-range/negative derived seeds cannot raise,
    # and so we can observe exactly what worker_init_fn passes along.
    recorded = {}

    def fake_np_seed(seed):
        recorded["seed"] = seed

    monkeypatch.setattr(np.random, "seed", fake_np_seed)

    # known torch state so the worker's derived seed is deterministic
    torch.manual_seed(42)
    worker_init_fn(0)

    # np.random.seed must have been invoked with an integer seed
    assert "seed" in recorded
    assert isinstance(recorded["seed"], int)
| 137 | + |
| 138 | + |
def test_save_and_load_task_list(tmp_path):
    """Task lists round-trip through the JSON file; a missing file yields []."""
    tasks = [{"x": 1}, {"y": 2}]
    task_file = tmp_path / "tasks.json"

    save_task_list(tasks, str(task_file))
    assert task_file.exists()
    assert load_task_list(str(task_file)) == tasks

    # a path that was never written loads as an empty list, not an error
    assert load_task_list(str(tmp_path / "nope.json")) == []
| 149 | + |
| 150 | + |
def test_check_training_status_new(tmp_path, monkeypatch):
    """With no prior trained/<id> directory, a fresh task file is requested (copy=True).

    Fix: the expected path suffix is built with os.path.join instead of a
    hard-coded "/" so the assertion also holds on platforms with a different
    path separator.
    """
    cfg = {"training_id": tmp_path.name, "devices": ["cpu"]}
    monkeypatch.chdir(tmp_path)

    # no trained/<id> directory exists yet
    task_list_path, copy = check_training_status(cfg)

    expected_suffix = os.path.join("trained", tmp_path.name, "train_tasks.json")
    assert task_list_path.endswith(expected_suffix)
    assert copy is True
| 158 | + |
| 159 | + |
def test_check_training_status_existing_same(tmp_path, monkeypatch):
    """A saved config that differs triggers the confirmation prompt; "y" overwrites.

    The saved config deliberately lacks 'training_id', so check_training_status
    detects a mismatch and calls input(); answering "y" should yield copy=True.
    NOTE(review): the test name says "same" but the fixture differs on purpose —
    name kept unchanged for backward compatibility with test selection.

    Fixes: removed two monkeypatch.setenv calls (PYTEST_DISABLE_PLUGIN_AUTOLOAD,
    PYTHONIOENCODING) — those variables are only read at interpreter/pytest
    startup, so setting them mid-test had no effect; expected path suffix is
    built with os.path.join instead of a hard-coded "/".
    """
    cfg = {"training_id": tmp_path.name, "devices": ["cpu"]}
    model_dir = tmp_path / "trained" / tmp_path.name
    model_dir.mkdir(parents=True)
    # saved config missing 'training_id' so the "differs" logic fires
    (model_dir / "config.yaml").write_text(
        yaml.safe_dump({"foo": 1, "devices": ["cpu"]})
    )
    monkeypatch.chdir(tmp_path)

    # answer the overwrite prompt with "y"
    monkeypatch.setattr("builtins.input", lambda prompt="": "y")

    task_list_path, copy = check_training_status(cfg)

    assert copy is True
    expected_suffix = os.path.join("trained", tmp_path.name, "train_tasks.json")
    assert task_list_path.endswith(expected_suffix)
| 179 | + |
| 180 | + |
def test_determine_batch_size():
    """Batch-size resolution: per-surrogate list, global scalar, and multiplier mode."""
    # per-surrogate list: the surrogate index selects the matching entry
    list_cfg = {"batch_size": [10, 20], "surrogates": ["A", "B"]}
    assert determine_batch_size(list_cfg, 1, mode="", metric=0) == 20

    # a single scalar applies to every surrogate
    scalar_cfg = {"batch_size": 5, "surrogates": ["X"]}
    assert determine_batch_size(scalar_cfg, 0, mode="", metric=0) == 5

    # in "batchsize" mode the metric multiplies the base size
    factor_cfg = {"batch_size": 8, "surrogates": ["X"]}
    assert determine_batch_size(factor_cfg, 0, mode="batchsize", metric=3) == 24

    # a list whose length differs from the surrogate count is rejected
    with pytest.raises(ValueError):
        determine_batch_size({"batch_size": [1], "surrogates": ["A", "B"]}, 0, "", 0)
| 194 | + |
| 195 | + |
@pytest.mark.parametrize(
    "raw,expected",
    [
        (0.5, 0.5),        # float passes through
        (2, 2.0),          # int is widened to float
        ("3.14", 3.14),    # numeric string is parsed
        ("1/4", 0.25),     # fraction string is evaluated
    ],
)
def test_batch_factor_to_float_valid(raw, expected):
    """Floats, ints, numeric strings and fraction strings all convert correctly."""
    assert batch_factor_to_float(raw) == pytest.approx(expected)
| 207 | + |
| 208 | + |
def test_batch_factor_to_float_invalid():
    """A string that is neither a number nor a fraction raises ValueError."""
    bogus = "not_a_number"
    with pytest.raises(ValueError):
        batch_factor_to_float(bogus)
0 commit comments