Skip to content

Commit 801b6ed

Browse files
committed
Restructure gpCAM gen
1 parent 04c6a3c commit 801b6ed

1 file changed

Lines changed: 21 additions & 21 deletions

File tree

libensemble/gen_funcs/persistent_gpCAM.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,11 @@ def _generate_mesh(lb, ub, num_points=10):
7575
return points
7676

7777

78-
def _update_gp_and_eval_var(all_x, all_y, x_for_var, test_points, persis_info):
78+
def _eval_var(my_gp2S, all_x, all_y, x_for_var, test_points, persis_info):
7979
"""
80-
Update the GP using the points in all_x and their function values in
81-
all_y. (We are assuming deterministic values in all_y, so we set the noise
82-
to be 1e-8 when build the GP.) Then evaluates the posterior covariance at
83-
points in x_for_var. If we have test points, calculate mean square error
84-
at those points.
80+
Evaluate the posterior covariance at points in x_for_var.
81+
If we have test points, calculate mean square error at those points.
8582
"""
86-
my_gp2S = GP(all_x, all_y, noise_variances=1e-12 * np.ones(len(all_y)))
87-
my_gp2S.train()
88-
8983
# Obtain covariance in groups to prevent memory overload.
9084
n_rows = x_for_var.shape[0]
9185
var_vals = []
@@ -105,6 +99,7 @@ def _update_gp_and_eval_var(all_x, all_y, x_for_var, test_points, persis_info):
10599
f_est = my_gp2S.posterior_mean(test_points["x"])["f(x)"]
106100
mse = np.mean((f_est - test_points["f"]) ** 2)
107101
persis_info.setdefault("mean_squared_error", []).append(mse)
102+
108103
return np.array(var_vals)
109104

110105

@@ -156,29 +151,32 @@ def persistent_gpCAM_simple(H_in, persis_info, gen_specs, libE_info):
156151
`test_gpCAM.py <https://github.com/Libensemble/libensemble/blob/develop/libensemble/tests/regression_tests/test_gpCAM.py>`_
157152
""" # noqa
158153
U = gen_specs["user"]
154+
my_gp2S = None
155+
noise = 1e-12
159156

160157
test_points = _read_testpoints(U)
161158

162159
batch_size, n, lb, ub, all_x, all_y, ps = _initialize_gpcAM(U, libE_info)
163160

164161
# Send batches until manager sends stop tag
165162
tag = None
166-
persis_info["max_variance"] = []
163+
var_vals = None
167164

168165
if U.get("use_grid"):
169166
num_points = 10
170167
x_for_var = _generate_mesh(lb, ub, num_points)
171168
r_low_init, r_high_init = calculate_grid_distances(lb, ub, num_points)
169+
else:
170+
x_for_var = persis_info["rand_stream"].uniform(lb, ub, (10 * batch_size, n))
172171

173172
while tag not in [STOP_TAG, PERSIS_STOP]:
174173
if all_x.shape[0] == 0:
175174
x_new = persis_info["rand_stream"].uniform(lb, ub, (batch_size, n))
176175
else:
177176
if not U.get("use_grid"):
178177
x_for_var = persis_info["rand_stream"].uniform(lb, ub, (10 * batch_size, n))
179-
var_vals = _update_gp_and_eval_var(all_x, all_y, x_for_var, test_points, persis_info)
180-
181-
if U.get("use_grid"):
178+
x_new = x_for_var[np.argsort(var_vals)[-batch_size:]]
179+
else:
182180
r_high = r_high_init
183181
r_low = r_low_init
184182
x_new = []
@@ -190,13 +188,12 @@ def persistent_gpCAM_simple(H_in, persis_info, gen_specs, libE_info):
190188
if len(x_new) < batch_size:
191189
r_high = r_cand
192190
r_cand = (r_high + r_low) / 2.0
193-
else:
194-
x_new = x_for_var[np.argsort(var_vals)[-batch_size:]]
195191

196192
H_o = np.zeros(batch_size, dtype=gen_specs["out"])
197193
H_o["x"] = x_new
198194
tag, Work, calc_in = ps.send_recv(H_o)
199195

196+
# This works with or without final_gen_send
200197
if calc_in is not None:
201198
y_new = np.atleast_2d(calc_in["f"]).T
202199
nan_indices = [i for i, fval in enumerate(y_new) if np.isnan(fval)]
@@ -205,12 +202,15 @@ def persistent_gpCAM_simple(H_in, persis_info, gen_specs, libE_info):
205202
all_x = np.vstack((all_x, x_new))
206203
all_y = np.vstack((all_y, y_new))
207204

208-
# If final points are sent with PERSIS_STOP, update model and get final var_vals
209-
if calc_in is not None:
210-
# H_o not updated by default - is persis_info
211-
if not U.get("use_grid"):
212-
x_for_var = persis_info["rand_stream"].uniform(lb, ub, (10 * batch_size, n))
213-
var_vals = _update_gp_and_eval_var(all_x, all_y, x_for_var, test_points, persis_info)
205+
if my_gp2S is None:
206+
my_gp2S = GP(all_x, all_y, noise_variances=noise * np.ones(len(all_y)))
207+
else:
208+
my_gp2S.tell(all_x, all_y, noise_variances=noise * np.ones(len(all_y)))
209+
my_gp2S.train()
210+
211+
if not U.get("use_grid"):
212+
x_for_var = persis_info["rand_stream"].uniform(lb, ub, (10 * batch_size, n))
213+
var_vals = _eval_var(my_gp2S, all_x, all_y, x_for_var, test_points, persis_info)
214214

215215
return H_o, persis_info, FINISHED_PERSISTENT_GEN_TAG
216216

0 commit comments

Comments
 (0)