@@ -128,10 +128,17 @@ def __init__(self, keys_en: dict, sim, redund_sim=None):
128128 elif 'controls' in self .keys_en :
129129 self .prior_info = extract .extract_initial_controls (self .keys_en )
130130
131+
132+ # Ensemble size
133+ self .ne = self .keys_en .get ('ne' , None )
134+
131135 # Calculate initial ensemble if IMPORTSTATICVAR has not been given in init. file.
132136 # Prior info. on state variables must be given by PRIOR_<STATICVAR-name> keyword.
133137 if 'importstaticvar' not in self .keys_en :
134- self .ne = int (self .keys_en ['ne' ])
138+ if self .ne is None :
139+ self .ne = 100
140+ else :
141+ self .ne = int (self .ne )
135142
136143 # Generate prior ensemble
137144 self .enX , self .idX , self .cov_prior = entools .generate_prior_ensemble (
@@ -144,15 +151,18 @@ def __init__(self, keys_en: dict, sim, redund_sim=None):
144151 # State variable imported as a Numpy save file
145152 tmp_load = np .load (self .keys_en ['importstaticvar' ], allow_pickle = True )
146153
154+ if self .ne is None :
155+ self .ne = tmp_load [key ].shape [1 ]
156+ else :
157+ self .ne = int (self .ne )
158+
147159 # We assume that the user has saved the state dict. as **state (effectively saved all keys in state
148160 # individually).
149161 for key in self .keys_en ['staticvar' ]:
150162 if self .enX is None :
151- self .enX = tmp_load [key ]
152- self .ne = self .enX .shape [1 ]
163+ self .enX = tmp_load [key ][:,:self .ne ]
153164 else :
154- assert self .ne == tmp_load [key ].shape [1 ], 'Ensemble size of imported state variables do not match!'
155- self .enX = np .vstack ((self .enX , tmp_load [key ]))
165+ self .enX = np .vstack ((self .enX , tmp_load [key ][:,:self .ne ]))
156166
157167 # fill in indices
158168 self .idX [key ] = (self .enX .shape [0 ] - tmp_load [key ].shape [0 ], self .enX .shape [0 ])
@@ -260,17 +270,11 @@ def calc_prediction(self, enX=None, save_prediction=None):
260270 en_pred = []
261271 pbar = tqdm (enumerate (enX ), total = self .ne , ** progbar_settings )
262272 for member_index , state in pbar :
263- en_pred .append (deepcopy ( self .sim .run_fwd_sim (state , member_index ) ))
273+ en_pred .append (self .sim .run_fwd_sim (state , member_index ))
264274
265275 # Parallelization on HPC using SLURM
266276 elif self .sim .input_dict .get ('hpc' , False ): # Run prediction in parallel on hpc
267277 en_pred = self .run_on_HPC (enX , batch_size = nparallel )
268-
269- # Parallellization internal to the simulator (e.g. batch processing on GPU )
270- elif self .sim .input_dict .get ('parallel_internal' , False ):
271- # make a single matrix for each state
272- batch_enX = {key : np .array ([d [key ] for d in enX ]) for key in enX [0 ].keys ()} # key: (b, state)
273- en_pred = self .sim .run_fwd_sim (batch_enX , member_i = None )
274278
275279 # Parallelization on local machine using p_map
276280 else :
0 commit comments