Skip to content

Commit 4c889e8

Browse files
committed
corrected impleentation of __init__ of the custom log scale transformer that has been inconsistent with SKLearn API
1 parent b0785e6 commit 4c889e8

4 files changed

Lines changed: 56 additions & 10 deletions

File tree

moddata/sklearn_extensions/log_standard_scaler.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Final
22

33
import numpy as np
4+
import pandas as pd
45
from sklearn.preprocessing import StandardScaler
56
from sklearn.base import BaseEstimator, TransformerMixin
67

@@ -31,9 +32,10 @@ def __init__(
3132
before log-transformation
3233
log_base: base of the log-transformation
3334
"""
34-
self._shift: Final[float] = shift
35-
self._standard_scaler: StandardScaler = StandardScaler()
36-
self._log_base: float | int = log_base
35+
self.shift: Final[float] = shift
36+
self.log_base: float | int = log_base
37+
38+
self.feature_names_in_ = None
3739

3840
@staticmethod
3941
def _validate_shift(shift: float) -> None:
@@ -42,16 +44,34 @@ def _validate_shift(shift: float) -> None:
4244
f"Received value {shift=}")
4345

4446
def _log_transform(self, X):
45-
return np.emath.logn(n=self._log_base, x=X + self._shift)
47+
return np.emath.logn(n=self.log_base, x=X + self.shift)
4648

47-
def fit(self, X, y=None):
49+
def fit(self, X: pd.DataFrame):
50+
self._standard_scaler: StandardScaler = StandardScaler() # noqa
51+
if not isinstance(X, pd.DataFrame):
52+
raise ValueError("This estimator only accepts pd.DataFrame input!")
53+
self.feature_names_in_ = list(X.columns)
4854
X_log = self._log_transform(X=X)
4955
self._standard_scaler.fit(X=X_log)
5056
return self
5157

52-
def transform(self, X):
58+
def transform(self, X: pd.DataFrame):
59+
if not isinstance(X, pd.DataFrame):
60+
raise ValueError("This transformer only accepts pd.DataFrame input!")
5361
X_log = self._log_transform(X=X)
54-
return self._standard_scaler.transform(X=X_log + self._shift)
62+
return self._standard_scaler.transform(X=X_log + self.shift)
5563

5664
def fit_transform(self, X, y=None, **fit_params):
5765
return self.fit(X=X).transform(X=X)
66+
67+
def get_feature_names_out(self, input_features=None):
68+
if input_features is None:
69+
if self.feature_names_in_ is not None:
70+
input_features = self.feature_names_in_
71+
else:
72+
raise ValueError("No input features provided and none "
73+
"were stored during fit! ")
74+
return [
75+
f"log_base_{self.log_base:0.2f}_shift_{self.shift:0.2f}_{feature}"
76+
for feature in input_features
77+
]

moddata/transformer/bankchurn_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _get_column_transformer(
6969
("ohe_gender_encoder", self._ohe_gender_encoder(), ["gender"]),
7070
("ohe_encode_country", self._ohe_encode_country(), ["country"]),
7171
("credit_score_dist_scaler", self._credit_score_dist_scaler(), ["credit_score"]),
72-
("estimated_salary_scaler", self._estimated_salary_scaler(), ["estimated_salary_scaler"]),
72+
("estimated_salary_scaler", self._estimated_salary_scaler(), ["estimated_salary"]),
7373
("age_scaler", self._age_scaler(), ["age"]),
7474
("balance_scaler", self._balance_scaler(), ["balance"])
7575
],

tests/pipeline/test_bankchurn_pipeline.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
2+
import pandas as pd
23

34
from moddata.pipeline.bankchurn_pipeline import BankchurnPipeline
45
from moddata.src.config import BankchurnPipelineConfig
56

67

7-
def test_bankchurn_pipeline_tree_like():
8+
def test_bankchurn_pipeline_tree_like_model():
89
X_train, X_test, y_train, y_test = BankchurnPipeline(
910
config=BankchurnPipelineConfig(
1011
random_state=12345,
@@ -21,4 +22,23 @@ def test_bankchurn_pipeline_tree_like():
2122
assert np.all(np.array(y_test.index[:3]) == np.array([7867, 1402, 8606]))
2223

2324

24-
test_bankchurn_pipeline_tree_like
25+
def test_bankchurn_pipeline_other_model():
26+
X_train, X_test, y_train, y_test = BankchurnPipeline(
27+
config=BankchurnPipelineConfig(
28+
random_state=12345,
29+
train_size=0.8,
30+
encoding_and_scaling_model_type="other"
31+
)
32+
).run()
33+
34+
assert X_train.shape == (8_000, 11)
35+
assert X_test.shape == (2_000, 11)
36+
assert y_train.shape == (8_000, 1)
37+
assert y_test.shape == (2_000, 1)
38+
39+
assert isinstance(X_train, pd.DataFrame)
40+
assert isinstance(X_test, pd.DataFrame)
41+
assert isinstance(y_train, pd.DataFrame)
42+
assert isinstance(y_test, pd.DataFrame)
43+
44+
assert np.all(np.array(y_test.index[:3]) == np.array([7867, 1402, 8606]))

tests/sklearn_extensions/test_log_standard_scaler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pandas as pd
23
from pytest import fixture
34

45

@@ -41,3 +42,8 @@ def test_log_standard_scaler_with_shift_and_base(make_log_normal_array):
4142
assert round(float(X_trfmd[0, 0]), 8) == 7.12454284
4243
assert X_trfmd.shape == (5, 1)
4344
assert round(float(X_trfmd[-1, 0]), 8) == 7.1258186
45+
46+
47+
def test_use_of_log_standard_scaler_in_column_transformer(make_log_normal_array):
48+
data: pd.DataFrame = pd.DataFrame(data={"X": make_log_normal_array})
49+
lss: LogStandardScaler = LogStandardScaler(log_base=2, shift=20)

0 commit comments

Comments
 (0)