@@ -176,8 +176,9 @@ def compute_J(self, m, f=None):
176176 delayed_compute_rows = delayed (compute_rows )
177177 sim = self
178178 for tInd , dt in zip (reversed (range (self .nT )), reversed (self .time_steps )):
179+
179180 AdiagTinv = Ainv [dt ]
180- j_row_updates = []
181+ future_updates = []
181182 time_mask = data_times > simulation_times [tInd ]
182183
183184 if not np .any (time_mask ):
@@ -197,56 +198,54 @@ def compute_J(self, m, f=None):
197198 client ,
198199 )
199200
201+ if client :
202+ field_derivatives = client .scatter (ATinv_df_duT_v , workers = self .worker )
203+ else :
204+ field_derivatives = ATinv_df_duT_v
205+
206+ for block_ind in range (len (blocks )):
207+
200208 if len (block ) == 0 :
201209 continue
202210
203- field_derivatives = ATinv_df_duT_v [ind ]
204211 if client :
205- field_derivatives = client .scatter (
206- ATinv_df_duT_v [ind ], workers = self .worker
212+ future_updates .append (
213+ client .submit (
214+ compute_rows ,
215+ sim ,
216+ tInd ,
217+ block_ind ,
218+ blocks ,
219+ field_derivatives ,
220+ fields_array ,
221+ time_mask ,
222+ workers = self .worker ,
223+ )
207224 )
208- for bb , row in enumerate (block ):
209- if client :
210- # field_derivatives = client.scatter(
211- # ATinv_df_duT_v[ind], workers=self.worker
212- # )
213- j_row_updates .append (
214- client .submit (
215- compute_rows ,
225+ else :
226+ future_updates .append (
227+ array .from_delayed (
228+ delayed_compute_rows (
216229 sim ,
217230 tInd ,
218- row ,
219- bb ,
231+ block_ind ,
232+ blocks ,
220233 field_derivatives ,
221234 fields_array ,
222235 time_mask ,
223- workers = self .worker ,
224- )
225- )
226- else :
227- j_row_updates .append (
228- array .from_delayed (
229- delayed_compute_rows (
230- sim ,
231- tInd ,
232- row ,
233- bb ,
234- field_derivatives ,
235- fields_array ,
236- time_mask ,
237- ),
238- dtype = np .float32 ,
239- shape = (
240- np .sum ([len (chunk [1 ][0 ]) for chunk in block ]),
241- m .size ,
242- ),
243- )
236+ ),
237+ dtype = np .float32 ,
238+ shape = (
239+ np .sum ([len (chunk [1 ][0 ]) for chunk in block ]),
240+ m .size ,
241+ ),
244242 )
243+ )
245244
246245 if client :
247- j_row_updates = np .vstack (client .gather (j_row_updates ))
246+ j_row_updates = np .vstack (client .gather (future_updates ))
248247 else :
249- j_row_updates = array .vstack (j_row_updates ).compute ()
248+ j_row_updates = array .vstack (future_updates ).compute ()
250249
251250 if self .store_sensitivities == "disk" :
252251 sens_name = self .sensitivity_path [:- 5 ] + f"_{ tInd % 2 } .zarr"
@@ -500,49 +499,54 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs):
500499def compute_rows (
501500 simulation ,
502501 tInd ,
503- chunks ,
504- ind ,
502+ block_ind ,
503+ blocks ,
505504 field_derivs ,
506505 fields ,
507506 time_mask ,
508507):
509508 """
510509 Compute the rows of the sensitivity matrix for a given source and receiver.
511510 """
512- (address , ind_array ) = chunks
513- # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
514- src = simulation .survey .source_list [address [0 ]]
515- time_check = np .kron (time_mask , np .ones (ind_array [2 ], dtype = bool ))[ind_array [0 ]]
516- local_ind = np .arange (len (ind_array [0 ]))[time_check ]
511+ rows = []
512+ for ind , (address , ind_array ) in enumerate (blocks [block_ind ]):
513+ # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
514+ src = simulation .survey .source_list [address [0 ]]
515+ time_check = np .kron (time_mask , np .ones (ind_array [2 ], dtype = bool ))[ind_array [0 ]]
516+ local_ind = np .arange (len (ind_array [0 ]))[time_check ]
517+
518+ if len (local_ind ) < 1 :
519+ row_block = np .zeros (
520+ (len (ind_array [1 ]), simulation .model .size ), dtype = np .float32
521+ )
522+ rows .append (row_block )
523+ continue
517524
518- if len (local_ind ) < 1 :
519- row_block = np .zeros (
520- (len (ind_array [1 ]), simulation .model .size ), dtype = np .float32
525+ dAsubdiagT_dm_v = simulation .getAsubdiagDeriv (
526+ tInd ,
527+ fields [:, address [0 ], tInd ],
528+ field_derivs [block_ind ][ind ][:, local_ind ],
529+ adjoint = True ,
521530 )
522- return row_block
523-
524- dAsubdiagT_dm_v = simulation .getAsubdiagDeriv (
525- tInd ,
526- fields [:, address [0 ], tInd ],
527- field_derivs [ind ][:, local_ind ],
528- adjoint = True ,
529- )
530531
531- dRHST_dm_v = simulation .getRHSDeriv (
532- tInd + 1 , src , field_derivs [ind ][:, local_ind ], adjoint = True
533- ) # on nodes of time mesh
532+ dRHST_dm_v = simulation .getRHSDeriv (
533+ tInd + 1 , src , field_derivs [ block_ind ] [ind ][:, local_ind ], adjoint = True
534+ ) # on nodes of time mesh
534535
535- un_src = fields [:, address [0 ], tInd + 1 ]
536- # cell centered on time mesh
537- dAT_dm_v = simulation .getAdiagDeriv (
538- tInd , un_src , field_derivs [ind ][:, local_ind ], adjoint = True
539- )
540- row_block = np .zeros ((len (ind_array [1 ]), simulation .model .size ), dtype = np .float32 )
541- row_block [time_check , :] = (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T .astype (
542- np .float32
543- )
536+ un_src = fields [:, address [0 ], tInd + 1 ]
537+ # cell centered on time mesh
538+ dAT_dm_v = simulation .getAdiagDeriv (
539+ tInd , un_src , field_derivs [block_ind ][ind ][:, local_ind ], adjoint = True
540+ )
541+ row_block = np .zeros (
542+ (len (ind_array [1 ]), simulation .model .size ), dtype = np .float32
543+ )
544+ row_block [time_check , :] = (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T .astype (
545+ np .float32
546+ )
547+ rows .append (row_block )
544548
545- return row_block
549+ return np . vstack ( rows )
546550
547551
548552def evaluate_dpred_block (indices , sources , mesh , time_mesh , fields ):
0 commit comments