Skip to content

Commit cc840c4

Browse files
committed
Added config for the BankchurnPipeline and created aliases for tuples of DataFrames
1 parent e1d1fd7 commit cc840c4

6 files changed

Lines changed: 28 additions & 24 deletions

File tree

moddata/extractor/bankchurn_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import pandas as pd
22

33
from moddata import load_data
4-
4+
from moddata.src.constants import XyDataFrames
55

66
class BankchurnExtractor:
77

8-
def extract(self) -> tuple[pd.DataFrame, pd.DataFrame]:
8+
def extract(self) -> XyDataFrames:
99
data: pd.DataFrame = load_data(dataset="bankchurn")
1010
x: pd.DataFrame = data.loc[:, data.columns != "churn"]
1111
x = x.drop(columns=["customer_id"])
Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,19 @@
1-
from typing import Final, Optional
2-
3-
import pandas as pd
1+
from typing import Final
42

53
from moddata.extractor.bankchurn_extractor import BankchurnExtractor
64
from moddata.transformer.bankchurn_transformer import BankchurnTransformer
7-
from moddata.src.constants import EncodingAndScalingModelType
5+
from moddata.src.constants import TrainTestXyDataFrames
6+
from moddata.src.config import BankchurnPipelineConfig
87

98

109
class BankchurnPipeline:
1110

12-
def __init__(
13-
self,
14-
train_size: float | int,
15-
random_state: Optional[int] = None,
16-
encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = None
17-
):
18-
self._random_state: Optional[int] = random_state
19-
self._encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = encoding_and_scaling_model_type
11+
def __init__(self, config: BankchurnPipelineConfig):
12+
self._config: Final[BankchurnPipelineConfig] = config
2013
self._transformer: Final[BankchurnTransformer] = (
21-
BankchurnTransformer(
22-
train_size=train_size,
23-
random_state=random_state,
24-
encoding_and_scaling_model_type=self._encoding_and_scaling_model_type
25-
)
14+
BankchurnTransformer(config=self._config)
2615
)
2716

28-
def run(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
17+
def run(self) -> TrainTestXyDataFrames:
2918
x, y = BankchurnExtractor().extract()
3019
return self._transformer.transform(data=(x, y))

moddata/src/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@
33
These are mainly bundles of values used not to clutter __init__-s
44
"""
55

6-
from typing import Optional
6+
from typing import Optional, TypeAlias
77

88
from pydantic import BaseModel, ConfigDict
99

1010
from moddata.src.constants import EncodingAndScalingModelType
1111

12+
__all__ = [
13+
"BankchurnTransformerConfig",
14+
"BankchurnPipelineConfig"
15+
]
16+
1217

1318
class BankchurnTransformerConfig(BaseModel):
1419
"""
@@ -23,3 +28,6 @@ class BankchurnTransformerConfig(BaseModel):
2328
train_size: float | int
2429
random_state: Optional[int] = None  # NOTE(review): original diff had a trailing comma here, which makes the pydantic field default the tuple (None,) instead of None — confirm and fix in the source commit
2530
encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = None
31+
32+
33+
BankchurnPipelineConfig: TypeAlias = BankchurnTransformerConfig

moddata/src/constants.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@
77

88
from typing import TypeAlias, Literal
99

10+
import pandas as pd
1011

1112
EncodingAndScalingModelType: TypeAlias = Literal[
1213
"tree_like",
1314
"other"
14-
]
15+
]
16+
17+
XyDataFrames: TypeAlias = tuple[pd.DataFrame, pd.DataFrame]
18+
19+
TrainTestXyDataFrames: TypeAlias = (
20+
tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]
21+
)

moddata/transformer/bankchurn_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from moddata.src.constants import EncodingAndScalingModelType
1111
from moddata.sklearn_extensions.log_standard_scaler import LogStandardScaler
1212
from moddata.src.config import BankchurnTransformerConfig
13+
from moddata.src.constants import TrainTestXyDataFrames
1314

1415

1516
class BankchurnTransformer:
@@ -82,7 +83,7 @@ def _get_column_transformer(
8283
def transform(
8384
self,
8485
data: tuple[pd.DataFrame, pd.DataFrame]
85-
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
86+
) -> TrainTestXyDataFrames:
8687
X, y = data
8788
X_train, X_test, y_train, y_test = train_test_split(
8889
X, y,

tests/sklearn_extensions/test_log_standard_scaler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def test_log_standard_scaler_with_shift(make_log_normal_array):
3030
def test_log_standard_scaler_with_base(make_log_normal_array):
3131
lss: LogStandardScaler = LogStandardScaler(log_base=2)
3232
X_trfmd: np.ndarray = lss.fit_transform(X=make_log_normal_array)
33-
# print(f"{X_trfmd=}")
3433
assert round(float(X_trfmd[0, 0]), 8) == -1.41835861
3534
assert X_trfmd.shape == (5, 1)
3635
assert round(float(X_trfmd[-1, 0]), 8) == -0.64893433

0 commit comments

Comments (0)