@@ -55,8 +55,10 @@ def _get_column_transformer(
5555 if encoding_and_scaling_model_type == "tree_like" :
5656 return ColumnTransformer (
5757 transformers = [
58- ("ohe_gender_encoder" , self ._ohe_gender_encoder (), ["gender" ]),
59- ("ohe_encode_country" , self ._ohe_encode_country (), ["country" ])
58+ ("ohe_gender_encoder" , self ._ohe_gender_encoder (),
59+ ["gender" ]),
60+ ("ohe_encode_country" , self ._ohe_encode_country (),
61+ ["country" ])
6062 ],
6163 remainder = "passthrough" ,
6264 force_int_remainder_cols = False ,
@@ -66,10 +68,20 @@ def _get_column_transformer(
6668 elif encoding_and_scaling_model_type == "other" :
6769 return ColumnTransformer (
6870 transformers = [
69- ("ohe_gender_encoder" , self ._ohe_gender_encoder (), ["gender" ]),
70- ("ohe_encode_country" , self ._ohe_encode_country (), ["country" ]),
71- ("credit_score_dist_scaler" , self ._credit_score_dist_scaler (), ["credit_score" ]),
72- ("estimated_salary_scaler" , self ._estimated_salary_scaler (), ["estimated_salary" ]),
71+ ("ohe_gender_encoder" , self ._ohe_gender_encoder (),
72+ ["gender" ]),
73+ ("ohe_encode_country" , self ._ohe_encode_country (),
74+ ["country" ]),
75+ (
76+ "credit_score_dist_scaler" ,
77+ self ._credit_score_dist_scaler (),
78+ ["credit_score" ]
79+ ),
80+ (
81+ "estimated_salary_scaler" ,
82+ self ._estimated_salary_scaler (),
83+ ["estimated_salary" ]
84+ ),
7385 ("age_scaler" , self ._age_scaler (), ["age" ]),
7486 ("balance_scaler" , self ._balance_scaler (), ["balance" ])
7587 ],
@@ -94,7 +106,9 @@ def transform(
94106 )
95107 if self ._config .encoding_and_scaling_model_type is not None :
96108 col_trfm : ColumnTransformer = self ._get_column_transformer (
97- encoding_and_scaling_model_type = self ._config .encoding_and_scaling_model_type
109+ encoding_and_scaling_model_type = (
110+ self ._config .encoding_and_scaling_model_type
111+ )
98112 )
99113 col_trfm .fit (X = X_train )
100114 X_train = pd .DataFrame (
0 commit comments