|
49 | 49 | parser.add_argument("--redo_pretraining", action='store_true') |
50 | 50 | parser.add_argument("--cache_file_path", default="../cache") |
51 | 51 | parser.add_argument("--log_mu_sigma", action='store_true', help="If true all mu and sigma outputs of the gp will be saved for all candidates") |
| 52 | +parser.add_argument("--only_pretraining", action="store_true") |
52 | 53 |
|
53 | 54 | args = parser.parse_args() |
54 | 55 |
|
|
162 | 163 | model.kern.kern_list[0].variance = float(args.k_var) |
163 | 164 | model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len) |
164 | 165 | model.kern.kern_list[1].variace = 1.0 |
| 166 | + elif args.kernel == "matern12": |
| 167 | + model = GPflow.sgpr.SGPR(gp_input, observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), Z=Z1) |
| 168 | + model.kern.kern_list[0].variance = float(args.k_var) |
| 169 | + model.kern.kern_list[0].lengthscales = np.ones(gp_input.shape[1]) * float(args.k_len) |
| 170 | + model.kern.kern_list[1].variace = 1e-5 |
165 | 171 | else: |
166 | 172 | raise ValueError("Chosen kernel is not implemented") |
167 | 173 | model.likelihood.variance = 0.001 |
|
226 | 232 |
|
227 | 233 | if args.verbose_opt == 'all': |
228 | 234 | print("n_new_points right before condition: {}".format(n_new_points)) |
229 | | - if n_new_points > 0: |
| 235 | + if n_new_points > 0 and not args.only_pretraining: |
230 | 236 | new_gp_input = persistent_scaler.transform(training_set_for_gp[-n_new_points:]) |
231 | 237 | new_observed_labels = observed_labels[-n_new_points:] |
232 | 238 |
|
|
255 | 261 | new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value |
256 | 262 | new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value |
257 | 263 | new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value |
| 264 | + elif args.kernel == "matern12": |
| 265 | + new_model = osgpr.OSGPR_VFE(new_gp_input, new_observed_labels, GPflow.kernels.Add([GPflow.kernels.Matern12(gp_input.shape[1], ARD=True), GPflow.kernels.White(gp_input.shape[1])]), mu, Su, Kaa, Zopt, Zinit) |
| 266 | + new_model.kern.kern_list[0].variance = model.kern.kern_list[0].variance.value |
| 267 | + new_model.kern.kern_list[0].lengthscales = model.kern.kern_list[0].lengthscales.value |
| 268 | + new_model.kern.kern_list[1].variance = model.kern.kern_list[1].variance.value |
258 | 269 |
|
259 | 270 | new_model.likelihood.variance = model.likelihood.variance.value |
260 | 271 | model = new_model |
|
0 commit comments