Skip to content

Commit 52b4349

Browse files
committed
matern and pre-only
1 parent 0bea5c0 commit 52b4349

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

Baselines/gp_user2Lquestion.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
parser.add_argument("--redo_pretraining", action='store_true')
5050
parser.add_argument("--cache_file_path", default="../cache")
5151
parser.add_argument("--log_mu_sigma", action='store_true', help="If true all mu and sigma outputs of the gp will be saved for all candidates")
52+
parser.add_argument("--only_pretraining", action="store_true")
5253

5354
args = parser.parse_args()
5455

@@ -162,6 +163,11 @@
162163
model.kern.kern_list[0].variance = float(args.k_var)
163164
model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len)
164165
model.kern.kern_list[1].variace = 1.0
166+
elif args.kernel == "matern12":
167+
model = GPflow.sgpr.SGPR(gp_input, observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), Z=Z1)
168+
model.kern.kern_list[0].variance = float(args.k_var)
169+
model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len)
170+
model.kern.kern_list[1].variace = 1e-5
165171
else:
166172
raise ValueError("Chosen kernel is not implemented")
167173
model.likelihood.variance = 0.001
@@ -226,7 +232,7 @@
226232

227233
if args.verbose_opt == 'all':
228234
print("n_new_points right before condition: {}".format(n_new_points))
229-
if n_new_points > 0:
235+
if n_new_points > 0 and not args.only_pretraining:
230236
new_gp_input = persistent_scaler.transform(training_set_for_gp[-n_new_points:])
231237
new_observed_labels = observed_labels[-n_new_points:]
232238

@@ -255,6 +261,11 @@
255261
new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value
256262
new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value
257263
new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value
264+
elif args.kernel == "matern12":
265+
new_model = osgpr.OSGPR_VFE(new_gp_input, new_observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), mu, Su, Kaa, Zopt, Zinit)
266+
new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value
267+
new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value
268+
new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value
258269

259270
new_model.likelihood.variance = model.likelihood.variance.value
260271
model = new_model

0 commit comments

Comments
 (0)