Skip to content

Commit 98b8403

Browse files
committed
minor column drop fix
1 parent f87218b commit 98b8403

1 file changed

Lines changed: 13 additions & 10 deletions

File tree

ml_grid/pipeline/data.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -172,9 +172,9 @@ def __init__(
172172
X=self.df, drop_list=self.drop_list, verbose=self.verbose)
173173

174174
self.final_column_list = [
175-
self.X
176-
for self.X in self.pertubation_columns
177-
if (self.X not in self.drop_list and self.X in self.df.columns)
175+
col
176+
for col in self.pertubation_columns
177+
if (col not in self.drop_list and col in self.df.columns)
178178
]
179179
# Add safety mechanism to retain minimum features
180180
min_required_features = 5 # Set your minimum threshold
@@ -194,17 +194,20 @@ def __init__(
194194

195195
# Update final columns and drop list
196196
self.final_column_list = safety_columns
197-
self.drop_list = [col for col in self.drop_list
198-
if col not in self.final_column_list]
197+
# Also update the main drop list to prevent re-pruning
198+
self.drop_list = [col for col in self.drop_list if col not in self.final_column_list]
199199

200200
print(f"Retaining minimum features: {self.final_column_list}")
201+
202+
# Re-filter final_column_list to be absolutely sure
203+
self.final_column_list = [col for col in self.pertubation_columns if col not in self.drop_list and col in self.df.columns]
204+
201205

202206
# Add two random features if list still empty
203207
if not self.final_column_list:
204208
print("Warning no feature columns retained, selecting two at random")
205-
final_column_list = []
206-
final_column_list.append(random.choice(self.orignal_feature_names))
207-
final_column_list.append(random.choice(self.orignal_feature_names))
209+
self.final_column_list.append(random.choice(self.orignal_feature_names))
210+
self.final_column_list.append(random.choice(self.orignal_feature_names))
208211

209212
# Ensure we still have at least 1 feature
210213
if not self.final_column_list:
@@ -324,9 +327,9 @@ def __init__(
324327
)
325328
try:
326329

330+
fim = feature_importance_methods()
327331
self.X_train, self.X_test, self.X_test_orig = (
328-
feature_importance_methods.handle_feature_importance_methods(
329-
self,
332+
fim.handle_feature_importance_methods(
330333
target_n_features_eval,
331334
X_train=self.X_train,
332335
X_test=self.X_test,

0 commit comments

Comments
 (0)