Skip to content

Commit dcf4bf8

Browse files
committed
unit test for score saving
1 parent 087bed7 commit dcf4bf8

1 file changed

Lines changed: 134 additions & 0 deletions

File tree

tests/test_project_score_save.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import unittest
2+
import shutil
3+
import tempfile
4+
import pandas as pd
5+
import numpy as np
6+
from pathlib import Path
7+
from unittest.mock import MagicMock, patch
8+
import sys
9+
import os
10+
11+
# Ensure the project root is in sys.path so we can import the module.
# The root is taken to be the parent of the directory holding this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
# Insert at the front, and only once, so repeated imports of this module do
# not keep growing sys.path.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
16+
17+
from ml_grid.util.project_score_save import project_score_save_class
18+
19+
class TestProjectScoreSave(unittest.TestCase):
    """Tests for ``project_score_save_class`` log creation and row writing.

    Every test works inside a fresh temporary directory and with the
    ``global_parameters`` name of ``ml_grid.util.project_score_save``
    replaced by a mock whose attributes are set in :meth:`setUp`.
    """

    def setUp(self):
        """Create the temp experiment directory and patch ``global_parameters``.

        Cleanup is registered with ``addCleanup`` rather than in ``tearDown``:
        cleanups still run when a later statement in ``setUp`` raises, so the
        temporary directory and the active patch are never leaked.
        """
        # Create a temporary directory for the experiment to avoid cluttering disk
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.experiment_dir = Path(self.test_dir) / "test_experiment"

        # Patch global_parameters to control configuration during tests.
        # Cleanups run last-in-first-out, so the patch is stopped before the
        # directory is removed (same order the previous tearDown used).
        self.patcher = patch("ml_grid.util.project_score_save.global_parameters")
        self.mock_globals = self.patcher.start()
        self.addCleanup(self.patcher.stop)

        # Default mock configuration
        self.mock_globals.metric_list = {"auc": "auc", "accuracy": "accuracy"}
        self.mock_globals.error_raise = True  # Important: Raise errors so tests fail on bugs
        self.mock_globals.bayessearch = False
        self.mock_globals.store_models = False

    def test_initialization(self):
        """Test that the log file is created with correct headers."""
        # Only the constructor's side effect (the log file) is inspected, so
        # the instance itself is not kept.
        project_score_save_class(str(self.experiment_dir))

        log_path = self.experiment_dir / "final_grid_score_log.csv"
        self.assertTrue(log_path.exists(), "Log file was not created")

        df = pd.read_csv(log_path)
        expected_cols = ["algorithm_implementation", "auc_m", "accuracy_m"]
        for col in expected_cols:
            self.assertIn(col, df.columns, f"Missing column: {col}")

    def test_update_score_log_success(self):
        """Test a successful write to the log file with all attributes present."""
        saver = project_score_save_class(str(self.experiment_dir))

        # Mock the ml_grid_object with all expected attributes
        mock_grid = MagicMock()
        mock_grid.X_train = [1, 2]
        mock_grid.y_train = [0, 1]
        mock_grid.X_test = pd.DataFrame({"a": [1]})
        mock_grid.y_test = pd.Series([1, 0])
        mock_grid.X_test_orig = [1, 2]
        mock_grid.y_test_orig = [1, 0]
        mock_grid.param_space_index = 1
        mock_grid.outcome_variable = "target"

        # Attributes that caused issues previously
        mock_grid.local_param_dict = {"param1": 10}
        mock_grid.final_column_list = ["col1"]
        mock_grid.original_feature_names = ["col1", "col2"]

        # Mock scores and algorithm
        scores = {
            "fit_time": [0.1],
            "score_time": [0.01],
            "test_auc": [0.8],
            "test_accuracy": [0.9],
        }
        best_pred = np.array([1, 0])
        algo = MagicMock()
        algo.get_params.return_value = {"p": 1}

        saver.update_score_log(
            ml_grid_object=mock_grid,
            scores=scores,
            best_pred_orig=best_pred,
            current_algorithm=algo,
            method_name="TestAlgo",
            pg=10,
            start=0,
            n_iter_v=5,
            failed=False,
        )

        # Verify data was written
        log_path = self.experiment_dir / "final_grid_score_log.csv"
        df = pd.read_csv(log_path)
        self.assertEqual(len(df), 1)
        self.assertEqual(df.iloc[0]["method_name"], "TestAlgo")
        # The value has been through float arithmetic and a CSV round trip,
        # so compare with a tolerance instead of exact equality.
        self.assertAlmostEqual(df.iloc[0]["auc_m"], 0.8)

    def test_update_score_log_typo_and_missing_safety(self):
        """Test that the code handles missing attributes and the 'orignal' typo."""
        saver = project_score_save_class(str(self.experiment_dir))

        mock_grid = MagicMock()
        # Minimal setup
        mock_grid.y_test = pd.Series([1, 0])
        mock_grid.param_space_index = 1

        # Simulate missing local_param_dict (should default to {}).
        # Deleting an attribute on a MagicMock makes later access raise
        # AttributeError instead of auto-creating a child mock.
        del mock_grid.local_param_dict

        # Simulate the typo: 'original' missing, 'orignal' present
        del mock_grid.original_feature_names
        mock_grid.orignal_feature_names = ["col1"]
        mock_grid.final_column_list = ["col1"]

        scores = {
            "fit_time": [0.1],
            "score_time": [0.01],
            "test_auc": [0.5],
            "test_accuracy": [0.5],
        }

        # Should not raise AttributeError
        saver.update_score_log(
            ml_grid_object=mock_grid,
            scores=scores,
            best_pred_orig=np.array([1, 0]),
            current_algorithm=MagicMock(),
            method_name="TypoTest",
            pg=1,
            start=0,
            n_iter_v=1,
            failed=False,
        )

        log_path = self.experiment_dir / "final_grid_score_log.csv"
        df = pd.read_csv(log_path)
        self.assertEqual(len(df), 1)
132+
133+
# Allow running this test module directly (python tests/test_project_score_save.py).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)