|
| 1 | +"""Tests for affine calibration: apply_affine, load_calibration, save_calibration.""" |
| 2 | + |
| 3 | +import json |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +from tiny_icf.calibration import apply_affine, load_calibration, save_calibration |
| 9 | + |
| 10 | + |
| 11 | +def test_apply_affine_formula(): |
| 12 | + """apply_affine(pred, a, b) == clip(a + b * pred, 0, 1).""" |
| 13 | + a, b = 0.1, 0.9 |
| 14 | + assert apply_affine(0.0, a, b) == pytest.approx(0.1) |
| 15 | + assert apply_affine(1.0, a, b) == pytest.approx(1.0) |
| 16 | + assert apply_affine(0.5, a, b) == pytest.approx(0.1 + 0.9 * 0.5) |
| 17 | + |
| 18 | + |
| 19 | +def test_apply_affine_clipping(): |
| 20 | + """Output is clipped to [0, 1].""" |
| 21 | + assert apply_affine(2.0, 0.0, 1.0) == 1.0 |
| 22 | + assert apply_affine(-1.0, 0.0, 1.0) == 0.0 |
| 23 | + assert apply_affine(0.5, 1.0, 1.0) == 1.0 |
| 24 | + assert apply_affine(0.5, -0.5, 0.5) == 0.0 |
| 25 | + |
| 26 | + |
| 27 | +def test_save_load_roundtrip(tmp_path: Path): |
| 28 | + """save_calibration then load_calibration returns same (a, b).""" |
| 29 | + cal_path = tmp_path / "model.pt.cal.json" |
| 30 | + save_calibration(cal_path, 0.05, 0.95) |
| 31 | + loaded = load_calibration(cal_path) |
| 32 | + assert loaded is not None |
| 33 | + assert loaded == (0.05, 0.95) |
| 34 | + data = json.loads(cal_path.read_text()) |
| 35 | + assert data["a"] == 0.05 and data["b"] == 0.95 |
| 36 | + |
| 37 | + |
| 38 | +def test_load_calibration_missing_returns_none(tmp_path: Path): |
| 39 | + """load_calibration on missing file returns None.""" |
| 40 | + assert load_calibration(tmp_path / "nonexistent.cal.json") is None |
| 41 | + |
| 42 | + |
| 43 | +def test_load_calibration_invalid_returns_none(tmp_path: Path): |
| 44 | + """load_calibration on invalid JSON returns None.""" |
| 45 | + bad = tmp_path / "bad.cal.json" |
| 46 | + bad.write_text("not json") |
| 47 | + assert load_calibration(bad) is None |
| 48 | + bad.write_text("{}") |
| 49 | + # empty dict: get("a", 0.0), get("b", 1.0) so (0.0, 1.0) is valid |
| 50 | + assert load_calibration(bad) == (0.0, 1.0) |
0 commit comments