import unittest

import numpy as np
import pandas as pd

from ml_grid.pipeline.data_constant_columns import (
    remove_constant_columns,
    remove_constant_columns_with_debug,
)


class TestRemoveConstantColumns(unittest.TestCase):
    """Check the drop list returned by ``remove_constant_columns``.

    Each test passes a DataFrame plus a starting drop list and inspects the
    list that comes back.
    """

    def test_remove_constant_columns_with_constants(self):
        """Test that constant columns are identified and added to the drop list."""
        df = pd.DataFrame({
            'a': [1, 2, 3],
            'b': [5, 5, 5],
            'c': ['x', 'y', 'z'],
            'd': [0, 0, 0],
        })
        initial_drop_list = ['e']

        # Hand over a copy so the expectation below is built from a list the
        # call could not have touched.
        updated_drop_list = remove_constant_columns(
            df, initial_drop_list.copy(), verbose=0
        )

        # assertCountEqual ignores ordering: only membership is checked.
        self.assertCountEqual(updated_drop_list, initial_drop_list + ['b', 'd'])

    def test_remove_constant_columns_no_constants(self):
        """Test that no columns are added when there are no constants."""
        df = pd.DataFrame({
            'a': [1, 2, 3],
            'b': [4, 5, 6],
        })

        updated_drop_list = remove_constant_columns(df, [], verbose=0)

        self.assertEqual(updated_drop_list, [])

    def test_remove_constant_columns_empty_df(self):
        """Test with an empty DataFrame."""
        df = pd.DataFrame()

        updated_drop_list = remove_constant_columns(df, [], verbose=0)

        self.assertEqual(updated_drop_list, [])


class TestRemoveConstantColumnsWithDebug(unittest.TestCase):
    """Check ``remove_constant_columns_with_debug`` on pandas and numpy inputs.

    Every test passes a train set, a test set and a copy of the test set, then
    inspects the three objects returned in that order.
    """

    def test_pandas_2d_constant_in_train(self):
        """Test with a constant column in the training DataFrame."""
        X_train = pd.DataFrame({'a': [1, 2, 3], 'b': [5, 5, 5]})
        X_test = pd.DataFrame({'a': [4, 5, 6], 'b': [7, 8, 9]})
        X_test_orig = X_test.copy()

        train_out, test_out, orig_out = remove_constant_columns_with_debug(
            X_train, X_test, X_test_orig, verbosity=0
        )

        # 'b' is constant in X_train, so it must disappear from all three
        # outputs even though it varies in the test data.
        self.assertNotIn('b', train_out.columns)
        self.assertNotIn('b', test_out.columns)
        self.assertNotIn('b', orig_out.columns)
        self.assertIn('a', train_out.columns)

    def test_pandas_2d_constant_in_test(self):
        """Test that a column constant only in the test set is NOT removed."""
        X_train = pd.DataFrame({'a': [1, 2, 3], 'b': [7, 8, 9]})
        X_test = pd.DataFrame({'a': [4, 5, 6], 'b': [5, 5, 5]})
        X_test_orig = X_test.copy()

        train_out, test_out, orig_out = remove_constant_columns_with_debug(
            X_train, X_test, X_test_orig, verbosity=0
        )

        # 'b' should NOT be removed as it has variance in the training set.
        self.assertIn('b', train_out.columns)
        self.assertIn('b', test_out.columns)
        self.assertIn('b', orig_out.columns)
        self.assertIn('a', train_out.columns)

    def test_numpy_2d(self):
        """Test with 2D numpy arrays."""
        # Second column of X_train is 5 in every row; the first column varies.
        X_train = np.array([[1, 5], [2, 5], [3, 5]])
        X_test = np.array([[4, 7], [5, 8], [6, 9]])
        X_test_orig = X_test.copy()

        train_out, test_out, orig_out = remove_constant_columns_with_debug(
            X_train, X_test, X_test_orig, verbosity=0
        )

        self.assertEqual(train_out.shape[1], 1)
        self.assertEqual(test_out.shape[1], 1)
        self.assertEqual(orig_out.shape[1], 1)
        # Reports the mismatching elements on failure instead of a bare
        # "False is not true".
        np.testing.assert_array_equal(train_out, np.array([[1], [2], [3]]))

    def test_numpy_3d_time_series(self):
        """Test with 3D numpy arrays for time series data."""
        # Shape: (samples, features, timesteps)
        # Features at index 0 and 2 differ between the two samples; the
        # feature at index 1 is 5 for every sample and timestep.
        X_train = np.array([
            [[1, 1], [5, 5], [1, 1]],  # Sample 1
            [[2, 2], [5, 5], [2, 2]],  # Sample 2
        ])
        X_test = np.array([
            [[3, 3], [9, 9], [3, 3]],
        ])
        X_test_orig = X_test.copy()

        train_out, test_out, orig_out = remove_constant_columns_with_debug(
            X_train, X_test, X_test_orig, verbosity=0
        )

        # Expecting the features at index 0 and index 2 to be kept and the
        # constant feature at index 1 to be dropped, leaving two features.
        self.assertEqual(train_out.shape[1], 2)
        self.assertEqual(test_out.shape[1], 2)
        self.assertEqual(orig_out.shape[1], 2)


# Run the test cases above when this file is executed directly.
if __name__ == '__main__':
    unittest.main()