Skip to content

Commit 5610828

Browse files
committed
test data pipeline
1 parent df761ca commit 5610828

1 file changed

Lines changed: 216 additions & 0 deletions

File tree

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
"""
2+
Unit tests for the ml_grid.pipeline.data.pipe class.
3+
4+
This test suite validates the core functionality of the data pipeline, ensuring
5+
that data is loaded, cleaned, transformed, and split correctly according to
6+
various configurations.
7+
"""
8+
9+
import unittest
10+
import pandas as pd
11+
import numpy as np
12+
import os
13+
import tempfile
14+
import shutil
15+
from pathlib import Path
16+
17+
# Ensure the project root is in the Python path to allow for module imports
18+
try:
19+
from ml_grid.pipeline.data import pipe, NoFeaturesError
20+
from ml_grid.util.global_params import global_parameters
21+
except ImportError:
22+
# This allows the test to be run from the project root directory
23+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
24+
from ml_grid.pipeline.data import pipe, NoFeaturesError
25+
from ml_grid.util.global_params import global_parameters
26+
27+
28+
class TestDataPipeline(unittest.TestCase):
    """Test suite for the data.pipe class.

    Each test constructs a full pipeline run against the small multiclass
    fixture CSV and asserts on the resulting attributes (data splits,
    feature-transformation log, column lists). The duplicated seven-argument
    ``pipe(...)`` call that previously appeared in every test is factored
    into the private ``_build_pipeline`` helper.
    """

    def setUp(self):
        """Set up a temporary environment for each test."""
        self.project_root = Path(__file__).resolve().parents[1]
        # Scratch directory for pipeline artifacts; removed in tearDown.
        self.test_dir = tempfile.mkdtemp()

        # Use the provided test data file
        self.test_data_path = self.project_root / "notebooks" / "test_data_hfe_1yr_m_small_multiclass.csv"
        if not self.test_data_path.exists():
            self.fail(f"Test data file not found at {self.test_data_path}")

        # Configure global parameters for testing
        global_parameters.verbose = 0  # Keep test output clean
        global_parameters.error_raise = True  # Fail loudly instead of swallowing errors
        global_parameters.bayessearch = False  # Explicitly set search mode

        # Define a base configuration for the pipeline. Individual tests
        # shallow-copy this dict and override scalar keys (or replace the
        # nested 'data' dict wholesale), so the shared base is never mutated.
        self.base_local_param_dict = {
            'outcome_var_n': 1,
            'param_space_size': 'small',
            'scale': True,
            'feature_n': 100,  # Use all features by default
            'use_embedding': False,
            'embedding_method': 'pca',
            'embedding_dim': 10,
            'scale_features_before_embedding': True,
            'percent_missing': 50,
            'correlation_threshold': 0.98,
            'corr': 0.98,
            'test_size': 0.25,
            'resample': None,
            'random_state': 42,
            'feature_selection_method': 'anova',
            'data': {
                'age': True, 'sex': True, 'bmi': True, 'ethnicity': True,
                'bloods': True, 'diagnostic_order': True, 'drug_order': True,
                'annotation_n': True, 'meta_sp_annotation_n': True,
                'annotation_mrc_n': True, 'meta_sp_annotation_mrc_n': True,
                'core_02': True, 'bed': True, 'vte_status': True,
                'hosp_site': True, 'core_resus': True, 'news': True,
                'date_time_stamp': True, 'appointments': True,
            }
        }
        self.drop_term_list = ['chrom', 'hfe', 'phlebo']
        self.model_class_dict = {'LogisticRegression_class': True}

    def tearDown(self):
        """Clean up the temporary directory after each test."""
        shutil.rmtree(self.test_dir)

    def _build_pipeline(self, param_space_index, local_param_dict=None):
        """Construct a ``pipe`` with the shared test configuration.

        Args:
            param_space_index: Index passed through to ``pipe``; each test
                uses a distinct value so runs are distinguishable.
            local_param_dict: Optional parameter dict overriding the default
                ``self.base_local_param_dict``.

        Returns:
            The fully initialized ``pipe`` instance.
        """
        if local_param_dict is None:
            local_param_dict = self.base_local_param_dict
        return pipe(
            file_name=str(self.test_data_path),
            drop_term_list=self.drop_term_list,
            experiment_dir=self.test_dir,
            base_project_dir=str(self.project_root),
            local_param_dict=local_param_dict,
            param_space_index=param_space_index,
            model_class_dict=self.model_class_dict,
        )

    def test_pipeline_initialization_successful(self):
        """Test that the pipeline initializes and runs without errors."""
        try:
            pipeline = self._build_pipeline(param_space_index=0)
            # Assert that key attributes are created and have the correct types
            self.assertIsInstance(pipeline.X_train, pd.DataFrame)
            self.assertIsInstance(pipeline.y_train, pd.Series)
            self.assertGreater(len(pipeline.final_column_list), 0)
            self.assertGreater(len(pipeline.model_class_list), 0)
            self.assertEqual(pipeline.outcome_variable, 'outcome_var_1')
        except Exception as e:
            self.fail(f"Pipeline initialization failed with an unexpected error: {e}")

    def test_no_constant_columns_in_final_X_train(self):
        """Verify that the final X_train contains no constant columns."""
        pipeline = self._build_pipeline(param_space_index=1)
        # A constant column has a variance of 0
        variances = pipeline.X_train.var(axis=0)
        constant_columns = variances[variances == 0].index.tolist()
        self.assertEqual(len(constant_columns), 0, f"Found constant columns in final X_train: {constant_columns}")

    def test_data_quality_in_final_data(self):
        """Check for NaN or infinite values in the final training data."""
        pipeline = self._build_pipeline(param_space_index=2)
        self.assertEqual(pipeline.X_train.isna().sum().sum(), 0, "Found NaN values in final X_train.")
        self.assertEqual(np.isinf(pipeline.X_train.select_dtypes(include=np.number)).sum().sum(), 0, "Found infinite values in final X_train.")

    def test_feature_importance_selection(self):
        """Test that feature importance selection correctly reduces column count."""
        params = self.base_local_param_dict.copy()
        params['feature_n'] = 50  # Select top 50% of features

        pipeline = self._build_pipeline(param_space_index=3, local_param_dict=params)

        # Get the number of features *before* importance selection
        log = pipeline.feature_transformation_log
        features_before_importance = log[log['step'] == 'Feature Importance']['features_before'].iloc[0]

        expected_features = int(features_before_importance * 0.50)
        # Allow for slight rounding differences
        self.assertAlmostEqual(pipeline.X_train.shape[1], expected_features, delta=1,
                               msg="Feature importance did not reduce features to ~50%.")

    def test_embedding_application(self):
        """Test that embedding correctly reduces features to the target dimension."""
        params = self.base_local_param_dict.copy()
        params['use_embedding'] = True
        params['embedding_dim'] = 15

        pipeline = self._build_pipeline(param_space_index=4, local_param_dict=params)

        self.assertEqual(pipeline.X_train.shape[1], params['embedding_dim'],
                         "Embedding did not reduce features to the target embedding_dim.")
        self.assertTrue(all(c.startswith('embed_') for c in pipeline.X_train.columns))

    def test_index_alignment(self):
        """Test that all final data splits have aligned indices."""
        pipeline = self._build_pipeline(param_space_index=5)
        self.assertTrue(pipeline.X_train.index.equals(pipeline.y_train.index), "X_train and y_train indices are not aligned.")
        self.assertTrue(pipeline.X_test.index.equals(pipeline.y_test.index), "X_test and y_test indices are not aligned.")
        self.assertTrue(pipeline.X_test_orig.index.equals(pipeline.y_test_orig.index), "X_test_orig and y_test_orig indices are not aligned.")

    def test_safety_net_activation(self):
        """Test that the safety net retains features when all are pruned."""
        params = self.base_local_param_dict.copy()
        # Create a config that will prune all features
        params['data'] = {key: False for key in params['data']}
        params['percent_missing'] = 0  # Drop any column with missing values
        params['correlation_threshold'] = 0.01  # Drop almost everything

        pipeline = self._build_pipeline(param_space_index=6, local_param_dict=params)

        # Check that the safety net was activated and retained some features
        log = pipeline.feature_transformation_log
        self.assertTrue('Safety Net' in log['step'].values, "Safety Net step was not logged.")
        self.assertGreater(pipeline.X_train.shape[1], 0, "Safety net failed to retain any features.")
213+
214+
215+
# Allow running this file directly (e.g. from a notebook or IDE): exit=False
# keeps the interpreter alive after the run instead of calling sys.exit().
if __name__ == '__main__':
    unittest.main(exit=False, argv=['first-arg-is-ignored'])

0 commit comments

Comments
 (0)