@@ -176,8 +176,9 @@ def compute_J(self, m, f=None):
176176 delayed_compute_rows = delayed (compute_rows )
177177 sim = self
178178 for tInd , dt in zip (reversed (range (self .nT )), reversed (self .time_steps )):
179+
179180 AdiagTinv = Ainv [dt ]
180- j_row_updates = []
181+ future_updates = []
181182 time_mask = data_times > simulation_times [tInd ]
182183
183184 if not np .any (time_mask ):
@@ -200,15 +201,15 @@ def compute_J(self, m, f=None):
200201 if len (block ) == 0 :
201202 continue
202203
203- if client :
204- field_derivatives = client .scatter (
205- atinv_block_deriv , workers = self .worker
206- )
207- else :
208- field_derivatives = atinv_block_deriv
204+ # if client:
205+ # field_derivatives = client.scatter(
206+ # atinv_block_deriv, workers=self.worker
207+ # )
208+ # else:
209+ field_derivatives = atinv_block_deriv
209210
210211 if client :
211- j_row_updates .append (
212+ future_updates .append (
212213 client .submit (
213214 compute_rows ,
214215 sim ,
@@ -221,7 +222,7 @@ def compute_J(self, m, f=None):
221222 )
222223 )
223224 else :
224- j_row_updates .append (
225+ future_updates .append (
225226 array .from_delayed (
226227 delayed_compute_rows (
227228 sim ,
@@ -241,9 +242,9 @@ def compute_J(self, m, f=None):
241242 ATinv_df_duT_v [ind ] = atinv_block_deriv
242243
243244 if client :
244- j_row_updates = np .vstack (client .gather (j_row_updates ))
245+ j_row_updates = np .vstack (client .gather (future_updates ))
245246 else :
246- j_row_updates = array .vstack (j_row_updates ).compute ()
247+ j_row_updates = array .vstack (future_updates ).compute ()
247248
248249 if self .store_sensitivities == "disk" :
249250 sens_name = self .sensitivity_path [:- 5 ] + f"_{ tInd % 2 } .zarr"
0 commit comments