Skip to content

Commit 3249996

Browse files
author
SamoraHunter
committed
feature selection column validation fix
1 parent 6b907b9 commit 3249996

1 file changed

Lines changed: 30 additions & 14 deletions

File tree

ml_grid/pipeline/data_feature_importance_methods.py

Lines changed: 30 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,15 +5,12 @@
55

66
from ml_grid.pipeline.data_feature_methods import feature_methods
77

8-
# rename this class
9-
10-
118
class feature_importance_methods:
129
"""A class to handle feature selection using different importance methods."""
1310

1411
def __init__(self) -> None:
1512
"""Initializes the feature_importance_methods class."""
16-
pass
13+
self.feature_method = "None"
1714

1815
def handle_feature_importance_methods(
1916
self,
@@ -50,6 +47,7 @@ def handle_feature_importance_methods(
5047
"""
5148

5249
logger = logging.getLogger("ml_grid")
50+
5351
# Work with copies to avoid modifying the original DataFrames in the calling scope
5452
X_train_copy = X_train.copy()
5553
X_test_copy = X_test.copy()
@@ -58,33 +56,51 @@ def handle_feature_importance_methods(
5856
self.feature_method = ml_grid_object.local_param_dict.get(
5957
"feature_selection_method"
6058
)
59+
60+
# Default to all features initially
61+
features = list(X_train_copy.columns)
6162

6263
if self.feature_method == "anova" or self.feature_method is None:
6364
logger.info("feature_method ANOVA")
6465
fm = feature_methods()
65-
# The data pipeline now guarantees a clean index, so no reset is needed here.
6666
features = fm.getNfeaturesANOVAF(
6767
n=target_n_features, X_train=X_train_copy, y_train=y_train
6868
)
6969

7070
elif self.feature_method == "markov_blanket":
7171
logger.info("feature method Markov")
7272
fm = feature_methods()
73-
# The data pipeline now guarantees a clean index, so no reset is needed here.
7473
features = fm.getNFeaturesMarkovBlanket(
7574
n=target_n_features, X_train=X_train_copy, y_train=y_train
7675
)
7776

7877
logger.info(f"target_n_features: {target_n_features}")
79-
logger.info(f"Selected features: {features}")
78+
79+
# --- Column Validation ---
80+
# Filter the requested 'features' to ensure they actually exist in the DataFrame.
81+
# This handles cases where selectors return indices, 'ColumnX' names, or
82+
# names that were dropped/renamed in previous pipeline steps.
83+
84+
valid_features = [f for f in features if f in X_train_copy.columns]
85+
86+
if len(valid_features) == 0:
87+
logger.warning(
88+
f"Feature selection ({self.feature_method}) returned 0 valid features. "
89+
f"Requested examples: {features[:5] if features else 'None'}. "
90+
"Falling back to ALL original features to prevent crash."
91+
)
92+
valid_features = list(X_train_copy.columns)
93+
elif len(valid_features) < len(features):
94+
logger.warning(
95+
f"{len(features) - len(valid_features)} selected features were not found in X_train columns. Dropped invalid keys."
96+
)
8097

81-
# CRITICAL FIX: Apply feature selection to the X_train that was passed in,
82-
# which has already been cleaned of post-split constant columns.
83-
X_train_out = X_train_copy[features]
98+
logger.info(f"Final selected features ({len(valid_features)}): {valid_features}")
8499

85-
# Apply the same feature selection to the test sets
86-
X_test_out = X_test.copy()[features]
87-
X_test_orig_out = X_test_orig.copy()[features]
100+
# Apply the validated selection
101+
X_train_out = X_train_copy[valid_features]
102+
X_test_out = X_test_copy[valid_features]
103+
X_test_orig_out = X_test_orig_copy[valid_features]
88104

89105
# The y series do not need to be modified, as they are already aligned.
90-
return X_train_out, y_train, X_test_out, y_test, X_test_orig_out
106+
return X_train_out, y_train, X_test_out, y_test, X_test_orig_out

0 commit comments

Comments
 (0)