Skip to content

Commit e138c0d

Browse files
committed
Re-work of block compute
1 parent 45a77df commit e138c0d

1 file changed

Lines changed: 87 additions & 40 deletions

File tree

SimPEG/dask/electromagnetics/time_domain/simulation.py

Lines changed: 87 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ....electromagnetics.time_domain.simulation import BaseTDEMSimulation as Sim
55
from ....utils import Zero
6+
from multiprocessing import cpu_count
67
import numpy as np
78
import scipy.sparse as sp
89
from time import time
@@ -105,7 +106,7 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields):
105106

106107
Sim.dpred = dask_dpred
107108
Sim.field_derivs = None
108-
Sim.j_initialzer = None
109+
109110

110111

111112
def compute_J(self, f=None, Ainv=None):
@@ -117,6 +118,7 @@ def compute_J(self, f=None, Ainv=None):
117118
row_chunks = int(np.ceil(
118119
float(self.survey.nD) / np.ceil(float(m_size) * self.survey.nD * 8. * 1e-6 / self.max_chunk_size)
119120
))
121+
120122
solution_type = self._fieldType + "Solution" # the thing we solved for
121123

122124
if self.store_sensitivities == "disk":
@@ -152,60 +154,67 @@ def compute_J(self, f=None, Ainv=None):
152154
self.field_derivs = dask.compute(field_derivs)[0]
153155

154156
f = dask.delayed(f)
155-
field_derivatives = {}
157+
field_derivatives = None
158+
batch_map = {}
156159

157160
for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))):
158161

159162
AdiagTinv = Ainv[dt]
160163
Asubdiag = self.getAsubdiag(tInd)
161164
d_count = 0
165+
block_count = 0
162166
field_deriv_blocks = []
163167
j_row_blocks = []
164-
168+
count = 0
169+
batch_block = []
170+
batch_indices = []
171+
batch_count = 0
165172
for isrc, src in enumerate(self.survey.source_list):
166173
field_blocks = []
167174
n_data = self.field_derivs[tInd+1][isrc][0].shape[1]
168175
n_blocks = int(np.ceil((m_size * n_data) * 8. * 1e-6 / 128.))
169176
sub_blocks = np.array_split(np.arange(n_data), n_blocks)
170177

171-
for block_ind in sub_blocks:
172-
if isrc not in field_derivatives:
173-
ATinv_df_duT_v = (
174-
AdiagTinv * self.field_derivs[tInd + 1][isrc][0][:, block_ind].toarray()
175-
)
176-
else:
177-
ATinv_df_duT_v = AdiagTinv * np.asarray(field_derivatives[isrc][:, block_ind])
178+
for i_block, block_ind in enumerate(sub_blocks):
178179

179-
if self.store_sensitivities == "disk":
180-
partial_derivs.set_orthogonal_selection(
181-
(slice(None), slice(d_count, d_count + len(block_ind))),
182-
ATinv_df_duT_v
183-
)
180+
if field_derivatives is None:
181+
batch_block.append(self.field_derivs[tInd + 1][isrc][0][:, block_ind].toarray())
182+
batch_map[isrc, i_block] = (batch_count, count)
184183
else:
185-
partial_derivs[:, d_count: d_count + len(block_ind)] = ATinv_df_duT_v
186-
187-
field_blocks.append(
188-
dask.array.from_delayed(
189-
delayed(parallel_field_deriv, pure=True)(
190-
partial_derivs[:, d_count: d_count + len(block_ind)], Asubdiag,
191-
self.field_derivs[tInd][isrc][0][:, block_ind]
192-
),
193-
shape=(Asubdiag.shape[0], len(block_ind)),
194-
dtype=np.float64
195-
)
196-
)
197-
j_row_blocks.append(dask.array.from_delayed(
198-
delayed(parallel_block_compute, pure=True)(
199-
self, f, src, partial_derivs[:, d_count: d_count + len(block_ind)],
200-
tInd, solution_type, d_count, Jmatrix, self.field_derivs[tInd + 1][isrc][1][block_ind, :]
201-
),
202-
shape=(len(block_ind), m_size),
203-
dtype=np.float32
204-
))
205-
d_count += len(block_ind)
206-
207-
field_deriv_blocks.append(dask.array.hstack(field_blocks))
184+
i_file, i_block = batch_map[isrc, i_block]
185+
batch_block.append(field_derivatives[i_file][:, i_block:i_block + len(block_ind)])
186+
187+
batch_indices.append((isrc, block_ind))
188+
block_count += 1
208189

190+
if block_count >= cpu_count():
191+
f_blocks, j_blocks = process_blocks(
192+
self, AdiagTinv, d_count, batch_block, batch_indices, Asubdiag, f, tInd,
193+
solution_type, Jmatrix
194+
)
195+
field_deriv_blocks.append(dask.array.hstack(f_blocks))
196+
j_row_blocks.append(j_blocks)
197+
198+
batch_block, batch_indices = [], []
199+
block_count = 0
200+
batch_count += 1
201+
d_count += count
202+
count = 0
203+
204+
count += len(block_ind)
205+
# if isrc not in field_derivatives:
206+
# ATinv_df_duT_v = (
207+
# AdiagTinv * self.field_derivs[tInd + 1][isrc][0][:, block_ind].toarray()
208+
# )
209+
# else:
210+
# ATinv_df_duT_v = AdiagTinv * np.asarray(field_derivatives[isrc][:, block_ind])
211+
212+
f_blocks, j_blocks = process_blocks(
213+
self, AdiagTinv, d_count, batch_block, batch_indices, Asubdiag, f, tInd,
214+
solution_type, Jmatrix
215+
)
216+
field_deriv_blocks.append(dask.array.hstack(f_blocks))
217+
j_row_blocks.append(j_blocks)
209218
del field_derivatives
210219

211220
if self.store_sensitivities == "disk":
@@ -224,8 +233,6 @@ def compute_J(self, f=None, Ainv=None):
224233
dask.compute(j_row_blocks)
225234
field_derivatives = dask.compute(field_deriv_blocks)[0]
226235

227-
field_derivatives = {isrc: elem for isrc, elem in enumerate(field_derivatives)}
228-
229236
for A in Ainv.values():
230237
A.clean()
231238

@@ -238,6 +245,46 @@ def compute_J(self, f=None, Ainv=None):
238245
Sim.compute_J = compute_J
239246

240247

248+
def process_blocks(
    self, AdiagTinv, d_count, batch_block, batch_indices, Asubdiag, f, tInd,
    solution_type, Jmatrix
):
    """
    Build the lazy dask tasks for one batch of field-derivative blocks.

    All blocks of the batch are stacked column-wise and multiplied by
    ``AdiagTinv`` in a single operation; the product is then sliced back
    into per-block column ranges, each of which feeds one
    ``parallel_field_deriv`` task and one ``parallel_block_compute`` task.

    :param self: simulation providing ``field_derivs`` and
        ``survey.source_list``.
    :param AdiagTinv: operator applied (via ``*``) to the stacked blocks;
        presumably the factorized transposed diagonal system for the
        current time step -- confirm against ``compute_J``.
    :param d_count: running offset forwarded to ``parallel_block_compute``
        for the first block; advanced by each block's column count.
    :param batch_block: list of 2D arrays, one per block, stacked
        column-wise.
    :param batch_indices: list of ``(isrc, block_ind)`` pairs aligned with
        ``batch_block``; ``isrc`` indexes the source list and ``block_ind``
        selects the block's entries in ``self.field_derivs``.
    :param Asubdiag: forwarded to ``parallel_field_deriv``; its row count
        sets the row dimension of each returned field block.
    :param f: fields object forwarded to ``parallel_block_compute``.
    :param tInd: time index used to look up ``self.field_derivs``.
    :param solution_type: forwarded to ``parallel_block_compute``.
    :param Jmatrix: forwarded to ``parallel_block_compute``; its column
        count sets the column dimension of each returned J block.

    :return: ``(field_blocks, j_row_blocks)``, two lists of dask arrays
        with one entry per block. Both lists are empty when the batch is
        empty.
    """
    # An empty batch occurs when the caller flushed exactly on its last
    # block; np.hstack would raise ValueError on an empty sequence.
    if len(batch_block) == 0:
        return [], []

    ATinv_df_duT_v = AdiagTinv * np.hstack(batch_block)
    field_blocks = []
    j_row_blocks = []
    count = 0

    for block, (isrc, block_ind) in zip(batch_block, batch_indices):
        block_size = block.shape[1]
        # Columns of the batched product that belong to this block.
        solved_block = ATinv_df_duT_v[:, count: count + block_size]

        field_blocks.append(
            dask.array.from_delayed(
                delayed(parallel_field_deriv, pure=True)(
                    solved_block, Asubdiag,
                    self.field_derivs[tInd][isrc][0][:, block_ind]
                ),
                shape=(Asubdiag.shape[0], block_size),
                dtype=np.float64
            )
        )
        j_row_blocks.append(
            dask.array.from_delayed(
                delayed(parallel_block_compute, pure=True)(
                    self, f,
                    self.survey.source_list[isrc],
                    solved_block,
                    tInd,
                    solution_type,
                    d_count,
                    Jmatrix,
                    self.field_derivs[tInd + 1][isrc][1][block_ind, :]
                ),
                shape=(block_size, Jmatrix.shape[1]),
                dtype=np.float32
            )
        )
        count += block_size
        d_count += block_size

    return field_blocks, j_row_blocks
286+
287+
241288
def block_deriv(simulation, src, tInd, f, block_size):
242289
src_field_derivs = []
243290
j_initial = []

0 commit comments

Comments
 (0)