Commit 5561fb5

Merge pull request #115 from MiraGeoscience/GEOPY-2182d
GEOPY-2182
2 parents: 1d91046 + 9f88d68

2 files changed: 81 additions & 77 deletions

simpeg/dask/electromagnetics/time_domain/simulation.py

Lines changed: 72 additions & 68 deletions
@@ -176,8 +176,9 @@ def compute_J(self, m, f=None):
     delayed_compute_rows = delayed(compute_rows)
     sim = self
     for tInd, dt in zip(reversed(range(self.nT)), reversed(self.time_steps)):
+
         AdiagTinv = Ainv[dt]
-        j_row_updates = []
+        future_updates = []
         time_mask = data_times > simulation_times[tInd]

         if not np.any(time_mask):
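Note: the accumulator is renamed from j_row_updates to future_updates because, after this change, the list holds pending Dask work (futures or lazy array blocks) rather than finished rows; everything is resolved once per time step by the gather/compute calls at the end of the next hunk. A minimal sketch of that submit-then-gather pattern, with a hypothetical make_row_block standing in for compute_rows:

import numpy as np
from dask.distributed import Client

def make_row_block(i):
    # Hypothetical stand-in for compute_rows: one block of J rows.
    return np.full((2, 4), i, dtype=np.float32)

client = Client(processes=False)  # assumption: a local scheduler for the demo

future_updates = [client.submit(make_row_block, i) for i in range(3)]
j_row_updates = np.vstack(client.gather(future_updates))  # blocks resolved here
print(j_row_updates.shape)  # (6, 4)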
@@ -197,56 +198,54 @@ def compute_J(self, m, f=None):
             client,
         )

+        if client:
+            field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker)
+        else:
+            field_derivatives = ATinv_df_duT_v
+
+        for block_ind in range(len(blocks)):
+
             if len(block) == 0:
                 continue

-            field_derivatives = ATinv_df_duT_v[ind]
             if client:
-                field_derivatives = client.scatter(
-                    ATinv_df_duT_v[ind], workers=self.worker
+                future_updates.append(
+                    client.submit(
+                        compute_rows,
+                        sim,
+                        tInd,
+                        block_ind,
+                        blocks,
+                        field_derivatives,
+                        fields_array,
+                        time_mask,
+                        workers=self.worker,
+                    )
                 )
-            for bb, row in enumerate(block):
-                if client:
-                    # field_derivatives = client.scatter(
-                    #     ATinv_df_duT_v[ind], workers=self.worker
-                    # )
-                    j_row_updates.append(
-                        client.submit(
-                            compute_rows,
                             sim,
                             tInd,
-                            row,
-                            bb,
+                            block_ind,
+                            blocks,
                             field_derivatives,
                             fields_array,
                             time_mask,
-                            workers=self.worker,
-                        )
-                    )
-                else:
-                    j_row_updates.append(
-                        array.from_delayed(
-                            delayed_compute_rows(
-                                sim,
-                                tInd,
-                                row,
-                                bb,
-                                field_derivatives,
-                                fields_array,
-                                time_mask,
-                            ),
-                            dtype=np.float32,
-                            shape=(
-                                np.sum([len(chunk[1][0]) for chunk in block]),
-                                m.size,
-                            ),
-                        )
+                        ),
+                        dtype=np.float32,
+                        shape=(
+                            np.sum([len(chunk[1][0]) for chunk in block]),
+                            m.size,
+                        ),
                     )
+                )

         if client:
-            j_row_updates = np.vstack(client.gather(j_row_updates))
+            j_row_updates = np.vstack(client.gather(future_updates))
         else:
-            j_row_updates = array.vstack(j_row_updates).compute()
+            j_row_updates = array.vstack(future_updates).compute()

         if self.store_sensitivities == "disk":
             sens_name = self.sensitivity_path[:-5] + f"_{tInd % 2}.zarr"
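Note: the restructure above changes the unit of parallel work from one row per task to one block per task: compute_rows is now submitted once per block_ind with the full blocks list, and the non-distributed path wraps the same call in dask.delayed plus array.from_delayed, which requires declaring the block's dtype and shape up front. A minimal sketch of that fallback branch, with a hypothetical make_row_block in place of compute_rows:

import numpy as np
from dask import array, delayed

def make_row_block(i):
    # Hypothetical stand-in for compute_rows: returns one stacked row block.
    return np.full((2, 4), i, dtype=np.float32)

delayed_make_row_block = delayed(make_row_block)

# from_delayed needs dtype and shape declared up front, which is why the
# diff computes the block height from the chunk index arrays (chunk[1][0]).
future_updates = [
    array.from_delayed(delayed_make_row_block(i), shape=(2, 4), dtype=np.float32)
    for i in range(3)
]
j_row_updates = array.vstack(future_updates).compute()
print(j_row_updates.shape)  # (6, 4)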
@@ -500,49 +499,54 @@ def deriv_block(ATinv_df_duT_v, Asubdiag, local_ind, field_derivs):
 def compute_rows(
     simulation,
     tInd,
-    chunks,
-    ind,
+    block_ind,
+    blocks,
     field_derivs,
     fields,
     time_mask,
 ):
     """
     Compute the rows of the sensitivity matrix for a given source and receiver.
     """
-    (address, ind_array) = chunks
-    # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
-    src = simulation.survey.source_list[address[0]]
-    time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]
-    local_ind = np.arange(len(ind_array[0]))[time_check]
+    rows = []
+    for ind, (address, ind_array) in enumerate(blocks[block_ind]):
+        # for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
+        src = simulation.survey.source_list[address[0]]
+        time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]
+        local_ind = np.arange(len(ind_array[0]))[time_check]
+
+        if len(local_ind) < 1:
+            row_block = np.zeros(
+                (len(ind_array[1]), simulation.model.size), dtype=np.float32
+            )
+            rows.append(row_block)
+            continue

-    if len(local_ind) < 1:
-        row_block = np.zeros(
-            (len(ind_array[1]), simulation.model.size), dtype=np.float32
+        dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
+            tInd,
+            fields[:, address[0], tInd],
+            field_derivs[block_ind][ind][:, local_ind],
+            adjoint=True,
         )
-        return row_block
-
-    dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
-        tInd,
-        fields[:, address[0], tInd],
-        field_derivs[ind][:, local_ind],
-        adjoint=True,
-    )

-    dRHST_dm_v = simulation.getRHSDeriv(
-        tInd + 1, src, field_derivs[ind][:, local_ind], adjoint=True
-    )  # on nodes of time mesh
+        dRHST_dm_v = simulation.getRHSDeriv(
+            tInd + 1, src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
+        )  # on nodes of time mesh

-    un_src = fields[:, address[0], tInd + 1]
-    # cell centered on time mesh
-    dAT_dm_v = simulation.getAdiagDeriv(
-        tInd, un_src, field_derivs[ind][:, local_ind], adjoint=True
-    )
-    row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32)
-    row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(
-        np.float32
-    )
+        un_src = fields[:, address[0], tInd + 1]
+        # cell centered on time mesh
+        dAT_dm_v = simulation.getAdiagDeriv(
+            tInd, un_src, field_derivs[block_ind][ind][:, local_ind], adjoint=True
+        )
+        row_block = np.zeros(
+            (len(ind_array[1]), simulation.model.size), dtype=np.float32
+        )
+        row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(
+            np.float32
+        )
+        rows.append(row_block)

-    return row_block
+    return np.vstack(rows)


 def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields):
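Note: compute_rows now consumes an entire block. It loops over the (address, ind_array) chunks of blocks[block_ind], handles chunks with no receiver times left after the current time step by appending an all-zero block instead of returning early, and stacks the per-chunk results so each task returns a single array. A minimal sketch of that accumulate-and-stack skeleton, with hypothetical (n_rows, active) chunk pairs in place of the real index arrays:

import numpy as np

def compute_block_rows(block_chunks, n_model):
    # Skeleton of the batched pattern: one float32 row block per chunk,
    # zeros where no times are active, stacked once per block.
    rows = []
    for n_rows, active in block_chunks:
        row_block = np.zeros((n_rows, n_model), dtype=np.float32)
        if active.any():
            row_block[active, :] = 1.0  # stand-in for the adjoint products
        rows.append(row_block)
    return np.vstack(rows)

chunks = [(3, np.array([True, False, True])), (2, np.array([False, False]))]
print(compute_block_rows(chunks, n_model=5).shape)  # (5, 5)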

simpeg/dask/utils.py

Lines changed: 9 additions & 9 deletions
@@ -76,14 +76,14 @@ def get_parallel_blocks(
         row_count += chunk_size

     # # Re-split over cpu_count if too few blocks
-    # if len(blocks) < thread_count and optimize:
-    #     flatten_blocks = []
-    #     for block in blocks:
-    #         flatten_blocks += block
-    #
-    #     chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count())
-    #     return [
-    #         [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0
-    #     ]
+    if len(blocks) < thread_count and optimize:
+        flatten_blocks = []
+        for block in blocks:
+            flatten_blocks += block
+
+        chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count())
+        return [
+            [flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0
+        ]

     return blocks
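Note: this re-enables the previously commented-out re-split. When get_parallel_blocks produces fewer blocks than available threads, the nested blocks are flattened and re-partitioned into cpu_count() chunks so no worker sits idle. A minimal sketch of that branch, assuming cpu_count comes from the standard-library multiprocessing module:

import numpy as np
from multiprocessing import cpu_count

def resplit(blocks):
    # Flatten the nested blocks, then split the flat list into roughly
    # equal chunks, one per available CPU.
    flatten_blocks = []
    for block in blocks:
        flatten_blocks += block
    chunks = np.array_split(np.arange(len(flatten_blocks)), cpu_count())
    return [[flatten_blocks[i] for i in chunk] for chunk in chunks if len(chunk) > 0]

print(resplit([[0, 1, 2, 3, 4, 5]]))  # one small block fans out across CPUs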
