Skip to content

Commit 557da10

Browse files
authored
Fix some multilevel issues (#147)
1 parent a1c3e04 commit 557da10

4 files changed

Lines changed: 40 additions & 32 deletions

File tree

src/ensemble/ensemble.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -476,20 +476,27 @@ def calc_ml_prediction(self, enX=None):
476476
if self.aux_input is not None:
477477
level_enX[n]['aux_input'] = self.aux_input[n]
478478

479-
480479
# Index list of ensemble members
481480
list_member_index = list(ml_ne)
482481

483-
# Run prediction in parallel using p_map
484-
en_pred = p_map(
485-
self.sim.run_fwd_sim,
486-
level_enX,
487-
list_member_index,
488-
num_cpus=no_tot_run,
489-
disable=self.disable_tqdm,
490-
**progbar_settings,
491-
)
482+
########################################################################################################
492483

484+
# Number of parallel runs
485+
if self.sim.input_dict.get('hpc', False): # Run prediction in parallel on hpc
486+
en_pred = self.run_on_HPC(level_enX, batch_size=nparallel)
487+
488+
# Parallelization on local machine using p_map
489+
else:
490+
en_pred = p_map(
491+
self.sim.run_fwd_sim,
492+
level_enX,
493+
list_member_index,
494+
num_cpus=no_tot_run,
495+
disable=self.disable_tqdm,
496+
**progbar_settings,
497+
)
498+
########################################################################################################
499+
493500
# List successful runs and crashes
494501
list_crash = [indx for indx, el in enumerate(en_pred) if el is False]
495502
list_success = [indx for indx, el in enumerate(en_pred) if el is not False]
@@ -531,10 +538,8 @@ def calc_ml_prediction(self, enX=None):
531538

532539
en_pred[list_crash[index]] = deepcopy(en_pred[element])
533540

534-
# Convert ensemble specific result into pred_data, and filter for NONE data
535-
ml_pred_data.append([{typ: np.concatenate(tuple((el[ind][typ][:, np.newaxis]) for el in en_pred), axis=1)
536-
if any(elem is not None for elem in tuple((el[ind][typ]) for el in en_pred))
537-
else None for typ in en_pred[0][0].keys()} for ind in range(len(en_pred[0]))])
541+
#Convert ensemble specific result into pred_data, and filter for NONE data
542+
ml_pred_data.append(dtools.en_pred_to_pred_data(en_pred))
538543

539544
# loop over time instance first, and the level instance.
540545
self.pred_data = np.array(ml_pred_data).T.tolist()

src/popt/loop/ensemble_base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def function(self, x, *args, **kwargs):
112112
self._aux_input()
113113

114114
# check for ensmble
115-
if len(x.shape) == 1:
115+
if len(x.shape) == 1:
116116
x = x[:,np.newaxis]
117117
self.ne = self.num_models
118118
else: self.ne = x.shape[1]
@@ -124,22 +124,22 @@ def function(self, x, *args, **kwargs):
124124
x = self._reorganize_multilevel_ensemble(x)
125125
x = self.scale_state(x).squeeze()
126126

127-
if self.enX is not None:
128-
self.enX = self.scale_state(self.enX)
127+
#if self.enX is not None:
128+
# self.enX = self.scale_state(self.enX)
129129

130130
# Evaluate the objective function
131131
if run_success:
132132
func_values = self.obj_func(
133133
self.pred_data,
134134
input_dict=self.sim.input_dict,
135135
true_order=self.sim.true_order,
136-
state=matrix_to_dict(self.enX, self.idX),
136+
state=matrix_to_dict(x, self.idX),
137137
**kwargs
138138
)
139139
else:
140140
func_values = np.inf # the simulations have crashed
141141

142-
if len(x.shape) == 1:
142+
if len(x.shape) == 1:
143143
self.stateF = func_values
144144
else:
145145
self.enF = func_values
@@ -258,13 +258,13 @@ def save_stateX(self, path='./', filetype='npz'):
258258
np.save(path + 'stateX.npy', stateX)
259259

260260
def _reorganize_multilevel_ensemble(self, x):
261-
if ('multilevel' in self.keys_en) and (len(x.shape) > 1):
262-
ml_ne = self.keys_en['multilevel']['ml_ne']
263-
x = ot.toggle_ml_state(x, ml_ne)
264-
return x
265-
else:
266-
return x
267-
261+
# Only toggle multilevel state when x is truly an ensemble (2D with >1 columns).
262+
# Treat shape (nx, 1) the same as a 1D vector.
263+
if 'multilevel' in self.keys_en:
264+
if isinstance(x,list) or ( x.ndim > 1 and (x.shape[1] > 1) ):
265+
ml_ne = self.multilevel['ml_ne']
266+
x = ot.toggle_ml_state(x, ml_ne)
267+
return x
268268

269269
def _aux_input(self):
270270
"""

src/popt/loop/ensemble_gaussian.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def gradient(self, x, *args, **kwargs):
8787
enX = np.random.multivariate_normal(self.stateX, self.covX, self.ne).T
8888

8989
# Shift ensemble to have correct mean
90-
self.enX = enX - enX.mean(axis=1, keepdims=True) + self.stateX[:,None]
90+
enX = enX - enX.mean(axis=1, keepdims=True) + self.stateX[:,None]
9191

9292
# Truncate to bounds
9393
if (self.lb is not None) and (self.ub is not None):
@@ -123,10 +123,13 @@ def gradient(self, x, *args, **kwargs):
123123
index += ne
124124

125125
if 'multilevel' in self.keys_en:
126-
weight = np.array(self.keys_en['multilevel']['ml_weights'])
127-
if not np.sum(weight) == 1.0:
128-
weight = weight / np.sum(weight)
129-
grad = np.dot(grad_ml, weight)
126+
weight = np.array(self.multilevel['ml_weights'])
127+
if len(weight) > 1:
128+
if not np.sum(weight) == 1.0:
129+
weight = weight / np.sum(weight)
130+
grad = np.dot(grad_ml, weight)
131+
else:
132+
grad = grad_ml[0]
130133
else:
131134
grad = grad_ml[0]
132135

src/popt/loop/optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def run_loop(self):
209209
self.logger(f'─────> EPF-EnOpt: {self.epf_iteration}, {r} (outer iteration, penalty factor)') # print epf info
210210
else:
211211
self.logger(f'─────> EPF-EnOpt: converged, no variables changed more than {conv_crit*100} %') # print epf info
212-
final_obj_no_penalty = str(round(float(self.fun(self.xk)),4))
212+
final_obj_no_penalty = str( round( float( np.mean(self.fun(self.xk)) ),4) )
213213
self.logger(f'─────> EPF-EnOpt: objective value without penalty = {final_obj_no_penalty}') # print epf info
214214
def save(self):
215215
"""

0 commit comments

Comments
 (0)