55
66from ml_grid .pipeline .data_feature_methods import feature_methods
77
8+
89class feature_importance_methods :
910 """A class to handle feature selection using different importance methods."""
1011
@@ -47,7 +48,7 @@ def handle_feature_importance_methods(
4748 """
4849
4950 logger = logging .getLogger ("ml_grid" )
50-
51+
5152 # Work with copies to avoid modifying the original DataFrames in the calling scope
5253 X_train_copy = X_train .copy ()
5354 X_test_copy = X_test .copy ()
@@ -56,7 +57,7 @@ def handle_feature_importance_methods(
5657 self .feature_method = ml_grid_object .local_param_dict .get (
5758 "feature_selection_method"
5859 )
59-
60+
6061 # Default to all features initially
6162 features = list (X_train_copy .columns )
6263
@@ -75,14 +76,14 @@ def handle_feature_importance_methods(
7576 )
7677
7778 logger .info (f"target_n_features: { target_n_features } " )
78-
79+
7980 # --- Column Validation ---
8081 # Filter the requested 'features' to ensure they actually exist in the DataFrame.
81- # This handles cases where selectors return indices, 'ColumnX' names, or
82+ # This handles cases where selectors return indices, 'ColumnX' names, or
8283 # names that were dropped/renamed in previous pipeline steps.
83-
84+
8485 valid_features = [f for f in features if f in X_train_copy .columns ]
85-
86+
8687 if len (valid_features ) == 0 :
8788 logger .warning (
8889 f"Feature selection ({ self .feature_method } ) returned 0 valid features. "
@@ -91,16 +92,18 @@ def handle_feature_importance_methods(
9192 )
9293 valid_features = list (X_train_copy .columns )
9394 elif len (valid_features ) < len (features ):
94- logger .warning (
95- f"{ len (features ) - len (valid_features )} selected features were not found in X_train columns. Dropped invalid keys."
96- )
95+ logger .warning (
96+ f"{ len (features ) - len (valid_features )} selected features were not found in X_train columns. Dropped invalid keys."
97+ )
9798
98- logger .info (f"Final selected features ({ len (valid_features )} ): { valid_features } " )
99+ logger .info (
100+ f"Final selected features ({ len (valid_features )} ): { valid_features } "
101+ )
99102
100- # Apply the validated selection
103+ # Apply the validated selection
101104 X_train_out = X_train_copy [valid_features ]
102105 X_test_out = X_test_copy [valid_features ]
103106 X_test_orig_out = X_test_orig_copy [valid_features ]
104107
105108 # The y series do not need to be modified, as they are already aligned.
106- return X_train_out , y_train , X_test_out , y_test , X_test_orig_out
109+ return X_train_out , y_train , X_test_out , y_test , X_test_orig_out
0 commit comments