Skip to content

Commit e87f69b

Browse files
committed
Fix warning with large graph
1 parent 33aa5d8 commit e87f69b

1 file changed

Lines changed: 12 additions & 4 deletions

File tree

simpeg/dask/electromagnetics/time_domain/simulation.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,12 @@ def compute_J(self, m, f=None):
200200
if len(block) == 0:
201201
continue
202202

203-
for row, field_derivatives in zip(block, ATinv_df_duT_v[ind]):
203+
field_derivatives = ATinv_df_duT_v[ind]
204+
if client:
205+
field_derivatives = client.scatter(
206+
ATinv_df_duT_v[ind], workers=self.worker
207+
)
208+
for bb, row in enumerate(block):
204209
if client:
205210
# field_derivatives = client.scatter(
206211
# ATinv_df_duT_v[ind], workers=self.worker
@@ -211,6 +216,7 @@ def compute_J(self, m, f=None):
211216
sim,
212217
tInd,
213218
row,
219+
bb,
214220
field_derivatives,
215221
fields_array,
216222
time_mask,
@@ -224,6 +230,7 @@ def compute_J(self, m, f=None):
224230
sim,
225231
tInd,
226232
row,
233+
bb,
227234
field_derivatives,
228235
fields_array,
229236
time_mask,
@@ -494,6 +501,7 @@ def compute_rows(
494501
simulation,
495502
tInd,
496503
chunks,
504+
ind,
497505
field_derivs,
498506
fields,
499507
time_mask,
@@ -516,18 +524,18 @@ def compute_rows(
516524
dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
517525
tInd,
518526
fields[:, address[0], tInd],
519-
field_derivs[:, local_ind],
527+
field_derivs[ind][:, local_ind],
520528
adjoint=True,
521529
)
522530

523531
dRHST_dm_v = simulation.getRHSDeriv(
524-
tInd + 1, src, field_derivs[:, local_ind], adjoint=True
532+
tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True
525533
) # on nodes of time mesh
526534

527535
un_src = fields[:, address[0], tInd + 1]
528536
# cell centered on time mesh
529537
dAT_dm_v = simulation.getAdiagDeriv(
530-
tInd, un_src, field_derivs[:, local_ind], adjoint=True
538+
tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True
531539
)
532540
row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32)
533541
row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(

0 commit comments

Comments
 (0)