Skip to content

Commit 32497c2

Browse files
committed
2 parents b1530bb + 52b4349 commit 32497c2

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

Baselines/gp_user2Lquestion.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
parser.add_argument("--redo_pretraining", action='store_true')
5151
parser.add_argument("--cache_file_path", default="../cache")
5252
parser.add_argument("--log_mu_sigma", action='store_true', help="If true all mu and sigma outputs of the gp will be saved for all candidates")
53+
parser.add_argument("--only_pretraining", action="store_true")
5354

5455
args = parser.parse_args()
5556

@@ -166,6 +167,11 @@
166167
model.kern.kern_list[0].variance = float(args.k_var)
167168
model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len)
168169
model.kern.kern_list[1].variace = 1.0
170+
elif args.kernel == "matern12":
171+
model = GPflow.sgpr.SGPR(gp_input, observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), Z=Z1)
172+
model.kern.kern_list[0].variance = float(args.k_var)
173+
model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len)
174+
model.kern.kern_list[1].variace = 1e-5
169175
else:
170176
raise ValueError("Chosen kernel is not implemented")
171177
model.likelihood.variance = 0.001
@@ -230,7 +236,7 @@
230236

231237
if args.verbose_opt == 'all':
232238
print("n_new_points right before condition: {}".format(n_new_points))
233-
if n_new_points > 0:
239+
if n_new_points > 0 and not args.only_pretraining:
234240
new_gp_input = persistent_scaler.transform(training_set_for_gp[-n_new_points:])
235241
new_observed_labels = observed_labels[-n_new_points:]
236242

@@ -259,6 +265,11 @@
259265
new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value
260266
new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value
261267
new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value
268+
elif args.kernel == "matern12":
269+
new_model = osgpr.OSGPR_VFE(new_gp_input, new_observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), mu, Su, Kaa, Zopt, Zinit)
270+
new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value
271+
new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value
272+
new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value
262273

263274
new_model.likelihood.variance = model.likelihood.variance.value
264275
model = new_model

0 commit comments

Comments
 (0)