|
50 | 50 | parser.add_argument("--redo_pretraining", action='store_true') |
51 | 51 | parser.add_argument("--cache_file_path", default="../cache") |
52 | 52 | parser.add_argument("--log_mu_sigma", action='store_true', help="If true all mu and sigma outputs of the gp will be saved for all candidates") |
| 53 | +parser.add_argument("--only_pretraining", action="store_true") |
53 | 54 |
|
54 | 55 | args = parser.parse_args() |
55 | 56 |
|
|
166 | 167 | model.kern.kern_list[0].variance = float(args.k_var) |
167 | 168 | model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len) |
168 | 169 | model.kern.kern_list[1].variace = 1.0 |
| 170 | + elif args.kernel == "matern12": |
| 171 | + model = GPflow.sgpr.SGPR(gp_input, observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), Z=Z1) |
| 172 | + model.kern.kern_list[0].variance = float(args.k_var) |
| 173 | + model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len) |
| 174 | + model.kern.kern_list[1].variace = 1e-5 |
169 | 175 | else: |
170 | 176 | raise ValueError("Chosen kernel is not implemented") |
171 | 177 | model.likelihood.variance = 0.001 |
|
230 | 236 |
|
231 | 237 | if args.verbose_opt == 'all': |
232 | 238 | print("n_new_points right before condition: {}".format(n_new_points)) |
233 | | - if n_new_points > 0: |
| 239 | + if n_new_points > 0 and not args.only_pretraining: |
234 | 240 | new_gp_input = persistent_scaler.transform(training_set_for_gp[-n_new_points:]) |
235 | 241 | new_observed_labels = observed_labels[-n_new_points:] |
236 | 242 |
|
|
259 | 265 | new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value |
260 | 266 | new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value |
261 | 267 | new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value |
| 268 | + elif args.kernel == "matern12": |
| 269 | + new_model = osgpr.OSGPR_VFE(new_gp_input, new_observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), mu, Su, Kaa, Zopt, Zinit) |
| 270 | + new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value |
| 271 | + new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value |
| 272 | + new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value |
262 | 273 |
|
263 | 274 | new_model.likelihood.variance = model.likelihood.variance.value |
264 | 275 | model = new_model |
|
0 commit comments