Skip to content

Commit e1d1fd7

Browse files
committed
created separate config file for managing parameters of the BankchurnTransformer
1 parent cc3d6c3 commit e1d1fd7

3 files changed

Lines changed: 39 additions & 27 deletions

File tree

moddata/pipeline/bankchurn_pipeline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from moddata.extractor.bankchurn_extractor import BankchurnExtractor
66
from moddata.transformer.bankchurn_transformer import BankchurnTransformer
7+
from moddata.src.constants import EncodingAndScalingModelType
78

89

910
class BankchurnPipeline:
@@ -12,12 +13,15 @@ def __init__(
1213
self,
1314
train_size: float | int,
1415
random_state: Optional[int] = None,
16+
encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = None
1517
):
1618
self._random_state: Optional[int] = random_state
19+
self._encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = encoding_and_scaling_model_type
1720
self._transformer: Final[BankchurnTransformer] = (
1821
BankchurnTransformer(
1922
train_size=train_size,
20-
random_state=random_state
23+
random_state=random_state,
24+
encoding_and_scaling_model_type=self._encoding_and_scaling_model_type
2125
)
2226
)
2327

moddata/src/config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Stores pydantic-style configuration classes.
2+
3+
These are mainly bundles of values used not to clutter __init__-s
4+
"""
5+
6+
from typing import Optional
7+
8+
from pydantic import BaseModel, ConfigDict
9+
10+
from moddata.src.constants import EncodingAndScalingModelType
11+
12+
13+
class BankchurnTransformerConfig(BaseModel):
    """Configuration bundle for BankchurnTransformer.

    Attributes:
        train_size: parameter passed to the train_test_split method
            used to create train and test datasets
        random_state: analogous to train_size
        encoding_and_scaling_model_type: Literal, defines what type of
            model data should be prepared for
    """
    # arbitrary_types_allowed lets pydantic accept the project-local
    # EncodingAndScalingModelType without a registered validator.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    train_size: float | int
    # FIX: the original line ended with a stray trailing comma
    # (`= None,`), which makes the field default the one-element tuple
    # `(None,)` instead of `None` — pydantic would then fail to validate
    # the default against Optional[int].
    random_state: Optional[int] = None
    encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = None

moddata/transformer/bankchurn_transformer.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Final
22

33
import pandas as pd
44
from sklearn.model_selection import train_test_split
@@ -9,29 +9,13 @@
99

1010
from moddata.src.constants import EncodingAndScalingModelType
1111
from moddata.sklearn_extensions.log_standard_scaler import LogStandardScaler
12+
from moddata.src.config import BankchurnTransformerConfig
1213

1314

1415
class BankchurnTransformer:
1516

16-
def __init__(
17-
self,
18-
train_size: float | int,
19-
random_state: Optional[int] = None,
20-
encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = None,
21-
):
22-
"""
23-
24-
Args:
25-
train_size: parameter passed to the train_test_split method
26-
used to create train and test datasets
27-
random_state: analogous to train_size
28-
encoding_and_scaling_model_type: Literal, defines what type of
29-
model data should be prepared for
30-
"""
31-
self._train_size: float | int = train_size
32-
self._random_state: Optional[int] = random_state
33-
self._encoding_and_scaling_model_type: Optional[EncodingAndScalingModelType] = (
34-
encoding_and_scaling_model_type)
17+
def __init__(self, config: BankchurnTransformerConfig):
18+
self._config: Final[BankchurnTransformerConfig] = config
3519

3620
@staticmethod
3721
def _ohe_gender_encoder() -> OneHotEncoder:
@@ -101,14 +85,13 @@ def transform(
10185
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
10286
X, y = data
10387
X_train, X_test, y_train, y_test = train_test_split(
104-
X,
105-
y,
106-
train_size=self._train_size,
107-
random_state=self._random_state
88+
X, y,
89+
train_size=self._config.train_size,
90+
random_state=self._config.random_state
10891
)
109-
if self._encoding_and_scaling_model_type is not None:
92+
if self._config.encoding_and_scaling_model_type is not None:
11093
col_trfm: ColumnTransformer = self._get_column_transformer(
111-
encoding_and_scaling_model_type=self._encoding_and_scaling_model_type
94+
encoding_and_scaling_model_type=self._config.encoding_and_scaling_model_type
11295
)
11396
X_train, y_train = col_trfm.fit_transform(X=X_train, y=y_train)
11497
X_test, y_test = col_trfm.transform(X=X_test, y=y_test)

0 commit comments

Comments
 (0)