Skip to content

Commit bbdfd17

Browse files
Fix MLflow model logging by correctly calling the train_model method
1 parent 20c108c commit bbdfd17

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

train_with_components.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,14 @@ def __init__(self, train_path, test_path):
236236
mlflow.log_metric("test_recall", model_trainer_artifact.test_metric_artifact.recallScore)
237237

238238
# Log model
239-
trained_model = model_trainer.train_model(
240-
*model_trainer.initiate_model_trainer().__dict__.values()
239+
# Get the training data
240+
train_arr = load_numpy_array_data(
241+
data_transformation_artifact.transformed_train_file_path
241242
)
243+
x_train, y_train = train_arr[:, :-1], train_arr[:, -1]
244+
245+
# Train a new model for MLflow logging
246+
trained_model = model_trainer.train_model(x_train, y_train)
242247

243248
mlflow.sklearn.log_model(
244249
sk_model=trained_model,

0 commit comments

Comments
 (0)