@@ -187,7 +187,7 @@ def compute_J(self, m, f=None):
187187 for ind , (block , field_deriv ) in enumerate (
188188 zip (blocks , times_field_derivs [tInd + 1 ], strict = True )
189189 ):
190- atinv_block_deriv = get_field_deriv_block (
190+ ATinv_df_duT_v [ ind ] = get_field_deriv_block (
191191 self ,
192192 block ,
193193 field_deriv ,
@@ -198,23 +198,24 @@ def compute_J(self, m, f=None):
198198 client ,
199199 )
200200
201+ if client :
202+ field_derivatives = client .scatter (ATinv_df_duT_v , workers = self .worker )
203+ else :
204+ field_derivatives = ATinv_df_duT_v
205+
206+ for block_ind in range (len (blocks )):
207+
201208 if len (block ) == 0 :
202209 continue
203210
204- # if client:
205- # field_derivatives = client.scatter(
206- # atinv_block_deriv, workers=self.worker
207- # )
208- # else:
209- field_derivatives = atinv_block_deriv
210-
211211 if client :
212212 future_updates .append (
213213 client .submit (
214214 compute_rows ,
215215 sim ,
216216 tInd ,
217- block ,
217+ block_ind ,
218+ blocks ,
218219 field_derivatives ,
219220 fields_array ,
220221 time_mask ,
@@ -227,7 +228,8 @@ def compute_J(self, m, f=None):
227228 delayed_compute_rows (
228229 sim ,
229230 tInd ,
230- block ,
231+ block_ind ,
232+ blocks ,
231233 field_derivatives ,
232234 fields_array ,
233235 time_mask ,
@@ -239,7 +241,6 @@ def compute_J(self, m, f=None):
239241 ),
240242 )
241243 )
242- ATinv_df_duT_v [ind ] = atinv_block_deriv
243244
244245 if client :
245246 j_row_updates = np .vstack (client .gather (future_updates ))
@@ -498,7 +499,8 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs):
498499def compute_rows (
499500 simulation ,
500501 tInd ,
501- block ,
502+ block_ind ,
503+ blocks ,
502504 field_derivs ,
503505 fields ,
504506 time_mask ,
@@ -507,7 +509,7 @@ def compute_rows(
507509 Compute the rows of the sensitivity matrix for a given source and receiver.
508510 """
509511 rows = []
510- for ind , (address , ind_array ) in enumerate (block ):
512+ for ind , (address , ind_array ) in enumerate (blocks [ block_ind ] ):
511513 # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
512514 src = simulation .survey .source_list [address [0 ]]
513515 time_check = np .kron (time_mask , np .ones (ind_array [2 ], dtype = bool ))[ind_array [0 ]]
@@ -523,18 +525,18 @@ def compute_rows(
523525 dAsubdiagT_dm_v = simulation .getAsubdiagDeriv (
524526 tInd ,
525527 fields [:, address [0 ], tInd ],
526- field_derivs [ind ][:, local_ind ],
528+ field_derivs [block_ind ][ ind ][:, local_ind ],
527529 adjoint = True ,
528530 )
529531
530532 dRHST_dm_v = simulation .getRHSDeriv (
531- tInd + 1 , src , field_derivs [ind ][:, local_ind ], adjoint = True
533+ tInd + 1 , src , field_derivs [block_ind ][ ind ][:, local_ind ], adjoint = True
532534 ) # on nodes of time mesh
533535
534536 un_src = fields [:, address [0 ], tInd + 1 ]
535537 # cell centered on time mesh
536538 dAT_dm_v = simulation .getAdiagDeriv (
537- tInd , un_src , field_derivs [ind ][:, local_ind ], adjoint = True
539+ tInd , un_src , field_derivs [block_ind ][ ind ][:, local_ind ], adjoint = True
538540 )
539541 row_block = np .zeros (
540542 (len (ind_array [1 ]), simulation .model .size ), dtype = np .float32
0 commit comments