|
1 | | -from typing import Optional |
| 1 | +from typing import Final |
2 | 2 |
|
3 | 3 | import pandas as pd |
4 | 4 | from sklearn.model_selection import train_test_split |
|
9 | 9 |
|
10 | 10 | from moddata.src.constants import EncodingAndScalingModelType |
11 | 11 | from moddata.sklearn_extensions.log_standard_scaler import LogStandardScaler |
| 12 | +from moddata.src.config import BankchurnTransformerConfig |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class BankchurnTransformer: |
15 | 16 |
|
16 | | - def __init__( |
17 | | - self, |
18 | | - train_size: float | int, |
19 | | - random_state: Optional[int] = None, |
20 | | - encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = None, |
21 | | - ): |
22 | | - """ |
23 | | -
|
24 | | - Args: |
25 | | - train_size: parameter passed to the train_test_split method |
26 | | - used to create train and test datasets |
27 | | - random_state: analogous to train_size |
28 | | - encoding_and_scaling_model_type: Literal, defines what type of |
29 | | - model data should be prepared for |
30 | | - """ |
31 | | - self._train_size: float | int = train_size |
32 | | - self._random_state: Optional[int] = random_state |
33 | | - self._encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = ( |
34 | | - encoding_and_scaling_model_type) |
| 17 | + def __init__(self, config: BankchurnTransformerConfig): |
| 18 | + self._config: Final[BankchurnTransformerConfig] = config |
35 | 19 |
|
36 | 20 | @staticmethod |
37 | 21 | def _ohe_gender_encoder() -> OneHotEncoder: |
@@ -101,14 +85,13 @@ def transform( |
101 | 85 | ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: |
102 | 86 | X, y = data |
103 | 87 | X_train, X_test, y_train, y_test = train_test_split( |
104 | | - X, |
105 | | - y, |
106 | | - train_size=self._train_size, |
107 | | - random_state=self._random_state |
| 88 | + X, y, |
| 89 | + train_size=self._config.train_size, |
| 90 | + random_state=self._config.random_state |
108 | 91 | ) |
109 | | - if self._encoding_and_scaling_model_type is not None: |
| 92 | + if self._config.encoding_and_scaling_model_type is not None: |
110 | 93 | col_trfm: ColumnTransformer = self._get_column_transformer( |
111 | | - encoding_and_scaling_model_type=self._encoding_and_scaling_model_type |
| 94 | + encoding_and_scaling_model_type=self._config.encoding_and_scaling_model_type |
112 | 95 | ) |
113 | 96 | X_train, y_train = col_trfm.fit_transform(X=X_train, y=y_train) |
114 | 97 | X_test, y_test = col_trfm.transform(X=X_test, y=y_test) |
|
0 commit comments