@@ -66,9 +66,15 @@ def get_data_split(X, y, local_param_dict):
6666 # print((y.shape))
6767 # print(X.shape)
6868
69+ # Store original column names and y name to reconstruct DataFrame after resampling
70+ original_columns = X .columns
71+ y_name = y .name
72+
6973 # Undersample data
7074 rus = RandomUnderSampler (random_state = 0 )
71- X , y = rus .fit_resample (X , y )
75+ X_res , y_res = rus .fit_resample (X , y )
76+ X = pd .DataFrame (X_res , columns = original_columns )
77+ y = pd .Series (y_res , name = y_name )
7278
7379 # Split into training and testing sets
7480 X_train_orig , X_test_orig , y_train_orig , y_test_orig = train_test_split (
@@ -89,10 +95,16 @@ def get_data_split(X, y, local_param_dict):
8995 X , y , test_size = 0.25 , random_state = 1
9096 )
9197
98+ # Store original column names and y name to reconstruct the DataFrame and Series after resampling
99+ original_columns = X_train_orig .columns
100+ y_name = y_train_orig .name
101+
92102 # Oversample training set
93103 sampling_strategy = 1
94104 ros = RandomOverSampler (sampling_strategy = sampling_strategy )
95- X_train_orig , y_train_orig = ros .fit_resample (X_train_orig , y_train_orig )
105+ X_train_orig_res , y_train_orig_res = ros .fit_resample (X_train_orig , y_train_orig )
106+ X_train_orig = pd .DataFrame (X_train_orig_res , columns = original_columns )
107+ y_train_orig = pd .Series (y_train_orig_res , name = y_name )
96108 print (y_train_orig .value_counts ())
97109
98110 # Split training set into final training and validation sets
0 commit comments