Skip to content

Commit 9fec6fe

Browse files
committed
Store column names when undersampling
1 parent 1135845 commit 9fec6fe

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

ml_grid/pipeline/data_train_test_split.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,15 @@ def get_data_split(X, y, local_param_dict):
6666
# print((y.shape))
6767
# print(X.shape)
6868

69+
# Store original column names and y name to reconstruct DataFrame after resampling
70+
original_columns = X.columns
71+
y_name = y.name
72+
6973
# Undersample data
7074
rus = RandomUnderSampler(random_state=0)
71-
X, y = rus.fit_resample(X, y)
75+
X_res, y_res = rus.fit_resample(X, y)
76+
X = pd.DataFrame(X_res, columns=original_columns)
77+
y = pd.Series(y_res, name=y_name)
7278

7379
# Split into training and testing sets
7480
X_train_orig, X_test_orig, y_train_orig, y_test_orig = train_test_split(
@@ -89,10 +95,16 @@ def get_data_split(X, y, local_param_dict):
8995
X, y, test_size=0.25, random_state=1
9096
)
9197

98+
# Store original column names and y name to reconstruct DataFrame after resampling
99+
original_columns = X_train_orig.columns
100+
y_name = y_train_orig.name
101+
92102
# Oversample training set
93103
sampling_strategy = 1
94104
ros = RandomOverSampler(sampling_strategy=sampling_strategy)
95-
X_train_orig, y_train_orig = ros.fit_resample(X_train_orig, y_train_orig)
105+
X_train_orig_res, y_train_orig_res = ros.fit_resample(X_train_orig, y_train_orig)
106+
X_train_orig = pd.DataFrame(X_train_orig_res, columns=original_columns)
107+
y_train_orig = pd.Series(y_train_orig_res, name=y_name)
96108
print(y_train_orig.value_counts())
97109

98110
# Split training set into final training and validation sets

0 commit comments

Comments
 (0)