Skip to content

Commit f418b38

Browse files
committed
reverted to pandas forlgm
1 parent 0669472 commit f418b38

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,10 @@ def __init__(
330330
if self.global_parameters.verbose >= 3:
331331
print("Fitting final model")
332332
#current_algorithm = grid.best_estimator_
333-
# Use numpy arrays for fitting the final model and for cross-validation.
334-
X_train_final_np = self.X_train.values
333+
# Pass the DataFrame for the final fit to support models that need column names (e.g., LightGBM wrapper).
334+
# For cross-validation, we will use numpy arrays for performance and compatibility.
335335
y_train_values = self.y_train.values
336-
337-
current_algorithm.fit(X_train_final_np, y_train_values)
336+
current_algorithm.fit(self.X_train, y_train_values)
338337

339338
metric_list = self.metric_list
340339

@@ -372,6 +371,7 @@ def __init__(
372371

373372
try:
374373
# Perform the cross-validation
374+
X_train_final_np = self.X_train.values
375375
scores = cross_validate(
376376
current_algorithm,
377377
X_train_final_np,
@@ -390,6 +390,7 @@ def __init__(
390390
current_algorithm.set_params(tree_method='hist')
391391

392392
try:
393+
X_train_final_np = self.X_train.values
393394
scores = cross_validate(
394395
current_algorithm,
395396
X_train_final_np,

0 commit comments

Comments
 (0)