Skip to content

Commit aa7d2c2

Browse files
committed
Delay large arrays
1 parent 2d98d0a commit aa7d2c2

1 file changed

Lines changed: 42 additions & 39 deletions

File tree

SimPEG/dask/electromagnetics/time_domain/simulation.py

Lines changed: 42 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int):
229229
chunk_size = len(chunk)
230230

231231
# Condition to start a new block
232-
if (row_count + chunk_size) > (data_block_size * cpu_count() / 2):
232+
if (row_count + chunk_size) > (data_block_size * cpu_count()):
233233
row_count = 0
234234
block_count += 1
235235
blocks[block_count] = {}
@@ -262,9 +262,17 @@ def deriv_block(
262262
return stacked_block
263263

264264

265-
def update_deriv_blocks(address, indices, derivatives, solve):
266-
columns, local_ind = indices[address]
267-
derivatives[address][:, local_ind] = solve[:, columns]
265+
def update_deriv_blocks(address, tInd, indices, derivatives, solve, shape):
266+
if address not in derivatives:
267+
deriv_array = np.zeros(shape)
268+
else:
269+
deriv_array = derivatives[address].compute()
270+
271+
if address in indices:
272+
columns, local_ind = indices[address]
273+
deriv_array[:, local_ind] = solve[:, columns]
274+
275+
derivatives[address] = delayed(deriv_array)
268276

269277

270278
def get_field_deriv_block(
@@ -298,44 +306,41 @@ def get_field_deriv_block(
298306
local_ind,
299307
)
300308
count += len(sub_ind)
309+
deriv_comp = deriv_block(
310+
s_id,
311+
r_id,
312+
b_id,
313+
ATinv_df_duT_v,
314+
Asubdiag,
315+
local_ind,
316+
sub_ind,
317+
simulation,
318+
tInd,
319+
)
301320

302321
stacked_blocks.append(
303-
deriv_block(
304-
s_id,
305-
r_id,
306-
b_id,
307-
ATinv_df_duT_v,
308-
Asubdiag,
309-
local_ind,
310-
sub_ind,
311-
simulation,
312-
tInd,
322+
array.from_delayed(
323+
deriv_comp,
324+
dtype=float,
325+
shape=(
326+
simulation.field_derivs[tInd][s_id][r_id].shape[0],
327+
len(local_ind),
328+
),
313329
)
314330
)
315-
316331
if len(stacked_blocks) > 0:
317-
solve = AdiagTinv * np.hstack(dask.compute(stacked_blocks)[0])
332+
blocks = array.hstack(stacked_blocks).compute()
333+
solve = AdiagTinv * blocks
318334

319335
update_list = []
320-
for s_id, r_id, b_id in block:
321-
if (s_id, r_id, b_id) not in ATinv_df_duT_v:
322-
ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros(
323-
(
324-
simulation.field_derivs[tInd][s_id][r_id].shape[0],
325-
len(block[(s_id, r_id, b_id)][0]),
326-
)
327-
)
328-
329-
if (s_id, r_id, b_id) in indices:
330-
update_list.append(
331-
update_deriv_blocks(
332-
(s_id, r_id, b_id),
333-
indices,
334-
ATinv_df_duT_v,
335-
solve,
336-
)
337-
)
338-
336+
for address in block:
337+
shape = (
338+
simulation.field_derivs[tInd][address[0]][address[1]].shape[0],
339+
len(block[address][0]),
340+
)
341+
update_list.append(
342+
update_deriv_blocks(address, tInd, indices, ATinv_df_duT_v, solve, shape)
343+
)
339344
dask.compute(update_list)
340345

341346
return ATinv_df_duT_v
@@ -395,7 +400,7 @@ def compute_J(self, f=None, Ainv=None):
395400
f, Ainv = self.fields(self.model, return_Ainv=True)
396401

397402
ftype = self._fieldType + "Solution"
398-
Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32)
403+
Jmatrix = delayed(np.zeros((self.survey.nD, self.model.size), dtype=np.float32))
399404
simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0
400405
data_times = self.survey.source_list[0].receiver_list[0].times
401406
blocks = get_parallel_blocks(
@@ -427,17 +432,15 @@ def compute_J(self, f=None, Ainv=None):
427432
time_mask,
428433
)
429434
)
430-
431435
dask.compute(j_row_updates)
432-
433436
for A in Ainv.values():
434437
A.clean()
435438

436439
if self.store_sensitivities == "disk":
437440
del Jmatrix
438441
return array.from_zarr(self.sensitivity_path + f"J.zarr")
439442
else:
440-
return Jmatrix
443+
return Jmatrix.compute()
441444

442445

443446
Sim.compute_J = compute_J

0 commit comments

Comments
 (0)