Skip to content

Commit e200638

Browse files
committed
2 parents d96892d + 7faedc9 commit e200638

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

Baselines/gp_user2Lquestion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from sklearn.gaussian_process import GaussianProcessRegressor
1717
from sklearn.gaussian_process.kernels import DotProduct, RBF, Matern, WhiteKernel
18-
from sklearn.preprocessing import normalize, StandardScaler
18+
from sklearn.preprocessing import normalize, StandardScaler, MinMaxScaler
1919

2020
from gp_utils import *
2121

@@ -43,6 +43,7 @@
4343
# e.g. --only_use_features "votes_sd affinity_sum tag_popularity votes_mean question_age"
4444
parser.add_argument("--beta", default=0.4, type=float, metavar="b",
4545
help="beta parameter for exploration (0=no exploration)")
46+
parser.add_argument("--scaler", default="standard", help="minmax or standard (for normalization)")
4647

4748
parser.add_argument("--sum_file_path", default="../cache/gp/runs/")
4849
parser.add_argument("--save_every_n", default=1000, type=int)
@@ -148,7 +149,10 @@
148149

149150
#With osgpr we pretrain immediately
150151
if model_choice == "osgpr":
151-
persistent_scaler = StandardScaler()
152+
if args.scaler=="minmax":
153+
persistent_scaler = MinMaxScaler()
154+
else:
155+
persistent_scaler = StandardScaler()
152156
gp_input = persistent_scaler.fit_transform(training_set_for_gp)
153157
Z1 = gp_input[np.random.permutation(gp_input.shape[0])[0:M_points], :]
154158
if args.kernel == "linear":

0 commit comments

Comments
 (0)