Skip to content

Commit ce6f8c3

Browse files
author
Henry Wallace
committed
tests: add test_calibration.py for apply_affine, save/load round-trip, missing/invalid
1 parent 04c76e3 commit ce6f8c3

1 file changed

Lines changed: 50 additions & 0 deletions

File tree

tests/test_calibration.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Tests for affine calibration: apply_affine, load_calibration, save_calibration."""
2+
3+
import json
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
from tiny_icf.calibration import apply_affine, load_calibration, save_calibration
9+
10+
11+
def test_apply_affine_formula():
12+
"""apply_affine(pred, a, b) == clip(a + b * pred, 0, 1)."""
13+
a, b = 0.1, 0.9
14+
assert apply_affine(0.0, a, b) == pytest.approx(0.1)
15+
assert apply_affine(1.0, a, b) == pytest.approx(1.0)
16+
assert apply_affine(0.5, a, b) == pytest.approx(0.1 + 0.9 * 0.5)
17+
18+
19+
def test_apply_affine_clipping():
20+
"""Output is clipped to [0, 1]."""
21+
assert apply_affine(2.0, 0.0, 1.0) == 1.0
22+
assert apply_affine(-1.0, 0.0, 1.0) == 0.0
23+
assert apply_affine(0.5, 1.0, 1.0) == 1.0
24+
assert apply_affine(0.5, -0.5, 0.5) == 0.0
25+
26+
27+
def test_save_load_roundtrip(tmp_path: Path):
28+
"""save_calibration then load_calibration returns same (a, b)."""
29+
cal_path = tmp_path / "model.pt.cal.json"
30+
save_calibration(cal_path, 0.05, 0.95)
31+
loaded = load_calibration(cal_path)
32+
assert loaded is not None
33+
assert loaded == (0.05, 0.95)
34+
data = json.loads(cal_path.read_text())
35+
assert data["a"] == 0.05 and data["b"] == 0.95
36+
37+
38+
def test_load_calibration_missing_returns_none(tmp_path: Path):
39+
"""load_calibration on missing file returns None."""
40+
assert load_calibration(tmp_path / "nonexistent.cal.json") is None
41+
42+
43+
def test_load_calibration_invalid_returns_none(tmp_path: Path):
44+
"""load_calibration on invalid JSON returns None."""
45+
bad = tmp_path / "bad.cal.json"
46+
bad.write_text("not json")
47+
assert load_calibration(bad) is None
48+
bad.write_text("{}")
49+
# empty dict: get("a", 0.0), get("b", 1.0) so (0.0, 1.0) is valid
50+
assert load_calibration(bad) == (0.0, 1.0)

0 commit comments

Comments
 (0)