Skip to content

Commit 9f88d68

Browse files
committed
But scatter on large array. Move indexing
1 parent 3854580 commit 9f88d68

1 file changed

Lines changed: 18 additions & 16 deletions

File tree

simpeg/dask/electromagnetics/time_domain/simulation.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def compute_J(self, m, f=None):
187187
for ind, (block, field_deriv) in enumerate(
188188
zip(blocks, times_field_derivs[tInd + 1], strict=True)
189189
):
190-
atinv_block_deriv = get_field_deriv_block(
190+
ATinv_df_duT_v[ind] = get_field_deriv_block(
191191
self,
192192
block,
193193
field_deriv,
@@ -198,23 +198,24 @@ def compute_J(self, m, f=None):
198198
client,
199199
)
200200

201+
if client:
202+
field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker)
203+
else:
204+
field_derivatives = ATinv_df_duT_v
205+
206+
for block_ind in range(len(blocks)):
207+
201208
if len(block) == 0:
202209
continue
203210

204-
# if client:
205-
# field_derivatives = client.scatter(
206-
# atinv_block_deriv, workers=self.worker
207-
# )
208-
# else:
209-
field_derivatives = atinv_block_deriv
210-
211211
if client:
212212
future_updates.append(
213213
client.submit(
214214
compute_rows,
215215
sim,
216216
tInd,
217-
block,
217+
block_ind,
218+
blocks,
218219
field_derivatives,
219220
fields_array,
220221
time_mask,
@@ -227,7 +228,8 @@ def compute_J(self, m, f=None):
227228
delayed_compute_rows(
228229
sim,
229230
tInd,
230-
block,
231+
block_ind,
232+
blocks,
231233
field_derivatives,
232234
fields_array,
233235
time_mask,
@@ -239,7 +241,6 @@ def compute_J(self, m, f=None):
239241
),
240242
)
241243
)
242-
ATinv_df_duT_v[ind] = atinv_block_deriv
243244

244245
if client:
245246
j_row_updates = np.vstack(client.gather(future_updates))
@@ -498,7 +499,8 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs):
498499
def compute_rows(
499500
simulation,
500501
tInd,
501-
block,
502+
block_ind,
503+
blocks,
502504
field_derivs,
503505
fields,
504506
time_mask,
@@ -507,7 +509,7 @@ def compute_rows(
507509
Compute the rows of the sensitivity matrix for a given source and receiver.
508510
"""
509511
rows = []
510-
for ind, (address, ind_array) in enumerate(block):
512+
for ind, (address, ind_array) in enumerate(blocks[block_ind]):
511513
# for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
512514
src = simulation.survey.source_list[address[0]]
513515
time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]
@@ -523,18 +525,18 @@ def compute_rows(
523525
dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
524526
tInd,
525527
fields[:, address[0], tInd],
526-
field_derivs[ind][:, local_ind],
528+
field_derivs[block_ind][ind][:, local_ind],
527529
adjoint=True,
528530
)
529531

530532
dRHST_dm_v = simulation.getRHSDeriv(
531-
tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True
533+
tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
532534
) # on nodes of time mesh
533535

534536
un_src = fields[:, address[0], tInd + 1]
535537
# cell centered on time mesh
536538
dAT_dm_v = simulation.getAdiagDeriv(
537-
tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True
539+
tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
538540
)
539541
row_block = np.zeros(
540542
(len(ind_array[1]), simulation.model.size), dtype=np.float32

0 commit comments

Comments
 (0)