@@ -200,7 +200,12 @@ def compute_J(self, m, f=None):
200200 if len (block ) == 0 :
201201 continue
202202
203- for row , field_derivatives in zip (block , ATinv_df_duT_v [ind ]):
203+ field_derivatives = ATinv_df_duT_v [ind ]
204+ if client :
205+ field_derivatives = client .scatter (
206+ ATinv_df_duT_v [ind ], workers = self .worker
207+ )
208+ for bb , row in enumerate (block ):
204209 if client :
205210 # field_derivatives = client.scatter(
206211 # ATinv_df_duT_v[ind], workers=self.worker
@@ -211,6 +216,7 @@ def compute_J(self, m, f=None):
211216 sim ,
212217 tInd ,
213218 row ,
219+ bb ,
214220 field_derivatives ,
215221 fields_array ,
216222 time_mask ,
@@ -224,6 +230,7 @@ def compute_J(self, m, f=None):
224230 sim ,
225231 tInd ,
226232 row ,
233+ bb ,
227234 field_derivatives ,
228235 fields_array ,
229236 time_mask ,
@@ -494,6 +501,7 @@ def compute_rows(
494501 simulation ,
495502 tInd ,
496503 chunks ,
504+ ind ,
497505 field_derivs ,
498506 fields ,
499507 time_mask ,
@@ -516,18 +524,18 @@ def compute_rows(
516524 dAsubdiagT_dm_v = simulation .getAsubdiagDeriv (
517525 tInd ,
518526 fields [:, address [0 ], tInd ],
519- field_derivs [:, local_ind ],
527+ field_derivs [ind ][ :, local_ind ],
520528 adjoint = True ,
521529 )
522530
523531 dRHST_dm_v = simulation .getRHSDeriv (
524- tInd + 1 , src , field_derivs [:, local_ind ], adjoint = True
532+ tInd + 1 , src , field_derivs [ind ][ :, local_ind ], adjoint = True
525533 ) # on nodes of time mesh
526534
527535 un_src = fields [:, address [0 ], tInd + 1 ]
528536 # cell centered on time mesh
529537 dAT_dm_v = simulation .getAdiagDeriv (
530- tInd , un_src , field_derivs [:, local_ind ], adjoint = True
538+ tInd , un_src , field_derivs [ind ][ :, local_ind ], adjoint = True
531539 )
532540 row_block = np .zeros ((len (ind_array [1 ]), simulation .model .size ), dtype = np .float32 )
533541 row_block [time_check , :] = (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T .astype (
0 commit comments