@@ -200,40 +200,41 @@ def compute_J(self, m, f=None):
200200 if len (block ) == 0 :
201201 continue
202202
203- if client :
204- field_derivatives = client .scatter (
205- ATinv_df_duT_v [ind ], workers = self .worker
206- )
207- j_row_updates .append (
208- client .submit (
209- compute_rows ,
210- sim ,
211- tInd ,
212- block ,
213- field_derivatives ,
214- fields_array ,
215- time_mask ,
216- workers = self .worker ,
217- )
218- )
219- else :
220- j_row_updates .append (
221- array .from_delayed (
222- delayed_compute_rows (
203+ for row , field_derivatives in zip (block , ATinv_df_duT_v [ind ]):
204+ if client :
205+ # field_derivatives = client.scatter(
206+ # ATinv_df_duT_v[ind], workers=self.worker
207+ # )
208+ j_row_updates .append (
209+ client .submit (
210+ compute_rows ,
223211 sim ,
224212 tInd ,
225- block ,
226- ATinv_df_duT_v [ ind ] ,
213+ row ,
214+ field_derivatives ,
227215 fields_array ,
228216 time_mask ,
229- ),
230- dtype = np .float32 ,
231- shape = (
232- np .sum ([len (chunk [1 ][0 ]) for chunk in block ]),
233- m .size ,
234- ),
217+ workers = self .worker ,
218+ )
219+ )
220+ else :
221+ j_row_updates .append (
222+ array .from_delayed (
223+ delayed_compute_rows (
224+ sim ,
225+ tInd ,
226+ row ,
227+ field_derivatives ,
228+ fields_array ,
229+ time_mask ,
230+ ),
231+ dtype = np .float32 ,
232+ shape = (
233+ np .sum ([len (chunk [1 ][0 ]) for chunk in block ]),
234+ m .size ,
235+ ),
236+ )
235237 )
236- )
237238
238239 if client :
239240 j_row_updates = np .vstack (client .gather (j_row_updates ))
@@ -390,59 +391,39 @@ def get_field_deriv_block(
390391 """
391392 Stack the blocks of field derivatives for a given timestep and call the direct solver.
392393 """
393- stacked_blocks = []
394394 if len (ATinv_df_duT_v ) == 0 :
395395 ATinv_df_duT_v = [[] for _ in block ]
396- indices = []
397- count = 0
398396
399397 Asubdiag = None
400398 if tInd < self .nT - 1 :
401399 Asubdiag = self .getAsubdiag (tInd + 1 )
402400
401+ updated_ATinv_df_duT_v = []
402+
403403 for (_ , (rx_ind , _ , shape )), field_deriv , ATinv_chunk in zip (
404404 block , field_derivs , ATinv_df_duT_v
405405 ):
406+
406407 # Cut out early data
407408 time_check = np .kron (time_mask , np .ones (shape , dtype = bool ))[rx_ind ]
408409 local_ind = np .arange (rx_ind .shape [0 ])[time_check ]
409- indices .append (
410- (np .arange (count , count + len (local_ind )), local_ind ),
411- )
412- count += len (local_ind )
413410
414411 if len (ATinv_chunk ) == 0 :
415412 # last timestep (first to be solved)
416- stacked_block = field_deriv .toarray ()[:, local_ind ]
417-
418- else :
419- stacked_block = np .asarray (
420- field_deriv [:, local_ind ] - Asubdiag .T * ATinv_chunk [:, local_ind ]
421- )
422-
423- stacked_blocks .append (stacked_block )
424-
425- blocks = np .hstack (stacked_blocks )
426- if blocks .ndim == 2 and blocks .shape [1 ] > 0 :
427- solve = (AdiagTinv * blocks ).reshape (blocks .shape )
428- else :
429- solve = None
430-
431- updated_ATinv_df_duT_v = []
432-
433- for (_ , arrays ), field_deriv , ATinv_chunk , (columns , local_ind ) in zip (
434- block , field_derivs , ATinv_df_duT_v , indices , strict = True
435- ):
436-
437- if len (ATinv_chunk ) == 0 :
413+ time_block = field_deriv .toarray ()[:, local_ind ]
438414 shape = (
439415 field_deriv .shape [0 ],
440- len (arrays [ 0 ] ),
416+ len (rx_ind ),
441417 )
442418 ATinv_chunk = np .zeros (shape , dtype = np .float32 )
419+ else :
420+ time_block = np .asarray (
421+ field_deriv [:, local_ind ] - Asubdiag .T * ATinv_chunk [:, local_ind ]
422+ )
443423
444- if solve is not None :
445- ATinv_chunk [:, local_ind ] = solve [:, columns ]
424+ if time_block .ndim == 2 and time_block .shape [1 ] > 0 :
425+ solve = (AdiagTinv * time_block ).reshape (time_block .shape )
426+ ATinv_chunk [:, local_ind ] = solve
446427
447428 updated_ATinv_df_duT_v .append (ATinv_chunk )
448429
@@ -513,52 +494,47 @@ def compute_rows(
513494 simulation ,
514495 tInd ,
515496 chunks ,
516- ATinv_df_duT_v ,
497+ field_derivs ,
517498 fields ,
518499 time_mask ,
519500):
520501 """
521502 Compute the rows of the sensitivity matrix for a given source and receiver.
522503 """
523- rows = []
504+ (address , ind_array ) = chunks
505+ # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
506+ src = simulation .survey .source_list [address [0 ]]
507+ time_check = np .kron (time_mask , np .ones (ind_array [2 ], dtype = bool ))[ind_array [0 ]]
508+ local_ind = np .arange (len (ind_array [0 ]))[time_check ]
524509
525- for (address , ind_array ), field_derivs in zip (chunks , ATinv_df_duT_v ):
526- src = simulation .survey .source_list [address [0 ]]
527- time_check = np .kron (time_mask , np .ones (ind_array [2 ], dtype = bool ))[ind_array [0 ]]
528- local_ind = np .arange (len (ind_array [0 ]))[time_check ]
529-
530- if len (local_ind ) < 1 :
531- row_block = np .zeros (
532- (len (ind_array [1 ]), simulation .model .size ), dtype = np .float32
533- )
534- rows .append (row_block )
535- continue
536-
537- dAsubdiagT_dm_v = simulation .getAsubdiagDeriv (
538- tInd ,
539- fields [:, address [0 ], tInd ],
540- field_derivs [:, local_ind ],
541- adjoint = True ,
542- )
543-
544- dRHST_dm_v = simulation .getRHSDeriv (
545- tInd + 1 , src , field_derivs [:, local_ind ], adjoint = True
546- ) # on nodes of time mesh
547-
548- un_src = fields [:, address [0 ], tInd + 1 ]
549- # cell centered on time mesh
550- dAT_dm_v = simulation .getAdiagDeriv (
551- tInd , un_src , field_derivs [:, local_ind ], adjoint = True
552- )
510+ if len (local_ind ) < 1 :
553511 row_block = np .zeros (
554512 (len (ind_array [1 ]), simulation .model .size ), dtype = np .float32
555513 )
556- row_block [time_check , :] = (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T .astype (
557- np .float32
558- )
559- rows .append (row_block )
514+ return row_block
515+
516+ dAsubdiagT_dm_v = simulation .getAsubdiagDeriv (
517+ tInd ,
518+ fields [:, address [0 ], tInd ],
519+ field_derivs [:, local_ind ],
520+ adjoint = True ,
521+ )
522+
523+ dRHST_dm_v = simulation .getRHSDeriv (
524+ tInd + 1 , src , field_derivs [:, local_ind ], adjoint = True
525+ ) # on nodes of time mesh
526+
527+ un_src = fields [:, address [0 ], tInd + 1 ]
528+ # cell centered on time mesh
529+ dAT_dm_v = simulation .getAdiagDeriv (
530+ tInd , un_src , field_derivs [:, local_ind ], adjoint = True
531+ )
532+ row_block = np .zeros ((len (ind_array [1 ]), simulation .model .size ), dtype = np .float32 )
533+ row_block [time_check , :] = (- dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v ).T .astype (
534+ np .float32
535+ )
560536
561- return np . vstack ( rows )
537+ return row_block
562538
563539
564540def evaluate_dpred_block (indices , sources , mesh , time_mesh , fields ):
0 commit comments