Skip to content

Commit 51c9cf0

Browse files
MathiasMNilsenKriFos1
authored andcommitted
Fix bug
1 parent 5660cf2 commit 51c9cf0

1 file changed

Lines changed: 16 additions & 12 deletions

File tree

ensemble/ensemble.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,17 @@ def __init__(self, keys_en: dict, sim, redund_sim=None):
128128
elif 'controls' in self.keys_en:
129129
self.prior_info = extract.extract_initial_controls(self.keys_en)
130130

131+
132+
# Ensemble size
133+
self.ne = self.keys_en.get('ne', None)
134+
131135
# Calculate initial ensemble if IMPORTSTATICVAR has not been given in init. file.
132136
# Prior info. on state variables must be given by PRIOR_<STATICVAR-name> keyword.
133137
if 'importstaticvar' not in self.keys_en:
134-
self.ne = int(self.keys_en['ne'])
138+
if self.ne is None:
139+
self.ne = 100
140+
else:
141+
self.ne = int(self.ne)
135142

136143
# Generate prior ensemble
137144
self.enX, self.idX, self.cov_prior = entools.generate_prior_ensemble(
@@ -144,15 +151,18 @@ def __init__(self, keys_en: dict, sim, redund_sim=None):
144151
# State variable imported as a Numpy save file
145152
tmp_load = np.load(self.keys_en['importstaticvar'], allow_pickle=True)
146153

154+
if self.ne is None:
155+
self.ne = tmp_load[key].shape[1]
156+
else:
157+
self.ne = int(self.ne)
158+
147159
# We assume that the user has saved the state dict. as **state (effectively saved all keys in state
148160
# individually).
149161
for key in self.keys_en['staticvar']:
150162
if self.enX is None:
151-
self.enX = tmp_load[key]
152-
self.ne = self.enX.shape[1]
163+
self.enX = tmp_load[key][:,:self.ne]
153164
else:
154-
assert self.ne == tmp_load[key].shape[1], 'Ensemble size of imported state variables do not match!'
155-
self.enX = np.vstack((self.enX, tmp_load[key]))
165+
self.enX = np.vstack((self.enX, tmp_load[key][:,:self.ne]))
156166

157167
# fill in indices
158168
self.idX[key] = (self.enX.shape[0] - tmp_load[key].shape[0], self.enX.shape[0])
@@ -260,17 +270,11 @@ def calc_prediction(self, enX=None, save_prediction=None):
260270
en_pred = []
261271
pbar = tqdm(enumerate(enX), total=self.ne, **progbar_settings)
262272
for member_index, state in pbar:
263-
en_pred.append(deepcopy(self.sim.run_fwd_sim(state, member_index)))
273+
en_pred.append(self.sim.run_fwd_sim(state, member_index))
264274

265275
# Parallelization on HPC using SLURM
266276
elif self.sim.input_dict.get('hpc', False): # Run prediction in parallel on hpc
267277
en_pred = self.run_on_HPC(enX, batch_size=nparallel)
268-
269-
# Parallellization internal to the simulator (e.g. batch processing on GPU )
270-
elif self.sim.input_dict.get('parallel_internal', False):
271-
# make a single matrix for each state
272-
batch_enX = {key: np.array([d[key] for d in enX]) for key in enX[0].keys()} # key: (b, state)
273-
en_pred = self.sim.run_fwd_sim(batch_enX, member_i=None)
274278

275279
# Parallelization on local machine using p_map
276280
else:

0 commit comments

Comments
 (0)