1616
1717from ml_grid .util .project_score_save import project_score_save_class
1818
19+
1920class TestProjectScoreSave (unittest .TestCase ):
2021
2122 def setUp (self ):
2223 # Create a temporary directory for the experiment to avoid cluttering disk
2324 self .test_dir = tempfile .mkdtemp ()
2425 self .experiment_dir = Path (self .test_dir ) / "test_experiment"
25-
26+
2627 # Patch global_parameters to control configuration during tests
2728 self .patcher = patch ("ml_grid.util.project_score_save.global_parameters" )
2829 self .mock_globals = self .patcher .start ()
29-
30+
3031 # Default mock configuration
3132 self .mock_globals .metric_list = {"auc" : "auc" , "accuracy" : "accuracy" }
32- self .mock_globals .error_raise = True # Important: Raise errors so tests fail on bugs
33+ self .mock_globals .error_raise = (
34+ True # Important: Raise errors so tests fail on bugs
35+ )
3336 self .mock_globals .bayessearch = False
3437 self .mock_globals .store_models = False
3538
@@ -40,10 +43,10 @@ def tearDown(self):
4043 def test_initialization (self ):
4144 """Test that the log file is created with correct headers."""
4245 saver = project_score_save_class (str (self .experiment_dir ))
43-
46+
4447 log_path = self .experiment_dir / "final_grid_score_log.csv"
4548 self .assertTrue (log_path .exists (), "Log file was not created" )
46-
49+
4750 df = pd .read_csv (log_path )
4851 expected_cols = ["algorithm_implementation" , "auc_m" , "accuracy_m" ]
4952 for col in expected_cols :
@@ -52,7 +55,7 @@ def test_initialization(self):
5255 def test_update_score_log_success (self ):
5356 """Test a successful write to the log file with all attributes present."""
5457 saver = project_score_save_class (str (self .experiment_dir ))
55-
58+
5659 # Mock the ml_grid_object with all expected attributes
5760 mock_grid = MagicMock ()
5861 mock_grid .X_train = [1 , 2 ]
@@ -63,16 +66,18 @@ def test_update_score_log_success(self):
6366 mock_grid .y_test_orig = [1 , 0 ]
6467 mock_grid .param_space_index = 1
6568 mock_grid .outcome_variable = "target"
66-
69+
6770 # Attributes that caused issues previously
6871 mock_grid .local_param_dict = {"param1" : 10 }
6972 mock_grid .final_column_list = ["col1" ]
7073 mock_grid .original_feature_names = ["col1" , "col2" ]
7174
7275 # Mock scores and algorithm
7376 scores = {
74- "fit_time" : [0.1 ], "score_time" : [0.01 ],
75- "test_auc" : [0.8 ], "test_accuracy" : [0.9 ]
77+ "fit_time" : [0.1 ],
78+ "score_time" : [0.01 ],
79+ "test_auc" : [0.8 ],
80+ "test_accuracy" : [0.9 ],
7681 }
7782 best_pred = np .array ([1 , 0 ])
7883 algo = MagicMock ()
@@ -87,7 +92,7 @@ def test_update_score_log_success(self):
8792 pg = 10 ,
8893 start = 0 ,
8994 n_iter_v = 5 ,
90- failed = False
95+ failed = False ,
9196 )
9297
9398 # Verify data was written
@@ -100,32 +105,40 @@ def test_update_score_log_success(self):
100105 def test_update_score_log_typo_and_missing_safety (self ):
101106 """Test that the code handles missing attributes and the 'orignal' typo."""
102107 saver = project_score_save_class (str (self .experiment_dir ))
103-
108+
104109 mock_grid = MagicMock ()
105110 # Minimal setup
106111 mock_grid .y_test = pd .Series ([1 , 0 ])
107112 mock_grid .param_space_index = 1
108-
113+
109114 # Simulate missing local_param_dict (should default to {})
110115 del mock_grid .local_param_dict
111-
116+
112117 # Simulate the typo: 'original' missing, 'orignal' present
113118 del mock_grid .original_feature_names
114- mock_grid .orignal_feature_names = ["col1" ]
119+ mock_grid .orignal_feature_names = ["col1" ]
115120 mock_grid .final_column_list = ["col1" ]
116121
117- scores = {"fit_time" : [0.1 ], "score_time" : [0.01 ], "test_auc" : [0.5 ], "test_accuracy" : [0.5 ]}
118-
122+ scores = {
123+ "fit_time" : [0.1 ],
124+ "score_time" : [0.01 ],
125+ "test_auc" : [0.5 ],
126+ "test_accuracy" : [0.5 ],
127+ }
128+
119129 # Should not raise AttributeError
120130 saver .update_score_log (
121131 ml_grid_object = mock_grid ,
122132 scores = scores ,
123133 best_pred_orig = np .array ([1 , 0 ]),
124134 current_algorithm = MagicMock (),
125135 method_name = "TypoTest" ,
126- pg = 1 , start = 0 , n_iter_v = 1 , failed = False
136+ pg = 1 ,
137+ start = 0 ,
138+ n_iter_v = 1 ,
139+ failed = False ,
127140 )
128-
141+
129142 log_path = self .experiment_dir / "final_grid_score_log.csv"
130143 df = pd .read_csv (log_path )
131144 self .assertEqual (len (df ), 1 )
@@ -135,18 +148,19 @@ def test_initialization_does_not_overwrite(self):
135148 # First initialization
136149 saver1 = project_score_save_class (str (self .experiment_dir ))
137150 log_path = self .experiment_dir / "final_grid_score_log.csv"
138-
151+
139152 # Simulate writing some data
140153 with open (log_path , "a" ) as f :
141154 f .write ("test_data_entry\n " )
142-
155+
143156 # Second initialization on same directory
144157 saver2 = project_score_save_class (str (self .experiment_dir ))
145-
158+
146159 # Verify data persists
147160 with open (log_path , "r" ) as f :
148161 content = f .read ()
149162 self .assertIn ("test_data_entry" , content )
150163
164+
151165if __name__ == "__main__" :
152- unittest .main ()
166+ unittest .main ()
0 commit comments