Skip to content

Commit cea15d1

Browse files
committed
Fix disk process
1 parent 2ff0dcc commit cea15d1

1 file changed

Lines changed: 69 additions & 71 deletions

File tree

SimPEG/dask/electromagnetics/time_domain/simulation.py

Lines changed: 69 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields):
105105

106106
Sim.dpred = dask_dpred
107107
Sim.field_derivs = None
108+
Sim.j_initialzer = None
109+
108110

109111
def compute_J(self, f=None, Ainv=None):
110112

@@ -115,50 +117,33 @@ def compute_J(self, f=None, Ainv=None):
115117
row_chunks = int(np.ceil(
116118
float(self.survey.nD) / np.ceil(float(m_size) * self.survey.nD * 8. * 1e-6 / self.max_chunk_size)
117119
))
120+
solution_type = self._fieldType + "Solution" # the thing we solved for
118121

119122
if self.store_sensitivities == "disk":
120-
self.J_initializer = zarr.open(
121-
self.sensitivity_path + f"J_initializer.zarr",
123+
Jmatrix = zarr.open(
124+
self.sensitivity_path + f"J.zarr",
122125
mode='w',
123126
shape=(self.survey.nD, m_size),
124127
chunks=(row_chunks, m_size)
125-
)
128+
)# + J_initializer
126129
else:
127-
self.J_initializer = np.zeros((self.survey.nD, m_size), dtype=np.float32)
128-
solution_type = self._fieldType + "Solution" # the thing we solved for
130+
Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)
129131

130132
if self.field_derivs is None:
131-
132-
# print("Start loop for field derivs")
133133
block_size = len(f[self.survey.source_list[0], solution_type, 0])
134-
135134
field_derivs = []
135+
136136
for tInd in range(self.nT + 1):
137137
d_count = 0
138138
df_duT_v = []
139139
for i_s, src in enumerate(self.survey.source_list):
140-
src_field_derivs = delayed(block_deriv, pure=True)(self, src, tInd, f, block_size, d_count)
140+
src_field_derivs = delayed(block_deriv, pure=True)(self, src, tInd, f, block_size)
141141
df_duT_v += [src_field_derivs]
142142
d_count += np.sum([rx.nD for rx in src.receiver_list])
143143

144144
field_derivs += [df_duT_v]
145-
# print("Dask loop field derivs")
146-
# tc = time()
147-
148145
self.field_derivs = dask.compute(field_derivs)[0]
149-
# print(f"Done in {time() - tc} seconds")
150146

151-
if self.store_sensitivities == "disk":
152-
Jmatrix = zarr.open(
153-
self.sensitivity_path + f"J.zarr",
154-
mode='w',
155-
shape=(self.survey.nD, m_size),
156-
chunks=(row_chunks, m_size)
157-
) + self.J_initializer
158-
else:
159-
Jmatrix = dask.delayed(np.zeros((self.survey.nD, m_size), dtype=np.float32) + self.J_initializer)
160-
161-
# ATinv_df_duT_v = {}
162147
f = dask.delayed(f)
163148
field_derivs_t = {}
164149

@@ -167,59 +152,76 @@ def compute_J(self, f=None, Ainv=None):
167152
AdiagTinv = Ainv[dt]
168153
Asubdiag = self.getAsubdiag(tInd)
169154
d_count = 0
155+
field_deriv_blocks = []
170156
row_blocks = []
171157

172-
# tc_loop = time()
173-
# print(f"Loop sources for {tInd}")
174158
for isrc, src in enumerate(self.survey.source_list):
175159
source_blocks = []
176-
# for block in range(len(self.field_derivs[tInd][isrc])):
177-
if isrc not in field_derivs_t:
178-
ATinv_df_duT_v = dask.delayed(AdiagTinv * self.field_derivs[tInd+1][isrc].toarray())
179-
else:
180-
ATinv_df_duT_v = dask.delayed(AdiagTinv * field_derivs_t[isrc])
181-
182-
n_data = self.field_derivs[tInd+1][isrc].shape[1]
160+
n_data = self.field_derivs[tInd+1][isrc][0].shape[1]
183161
n_blocks = int(np.ceil((m_size * n_data) * 8. * 1e-6 / 128.))
184-
ind_col = np.array_split(np.arange(n_data), n_blocks)
162+
sub_blocks = np.array_split(np.arange(n_data), n_blocks)
163+
164+
for block_ind in sub_blocks:
165+
if isrc not in field_derivs_t:
166+
ATinv_df_duT_v = (
167+
AdiagTinv * self.field_derivs[tInd + 1][isrc][0][:, block_ind].toarray()
168+
)
169+
else:
170+
ATinv_df_duT_v = AdiagTinv * np.asarray(field_derivs_t[isrc][:, block_ind])
171+
172+
delayed_J_block = delayed(parallel_block_compute, pure=True)(
173+
self, f, src, ATinv_df_duT_v,
174+
tInd, solution_type, d_count, Jmatrix, self.field_derivs[tInd+1][isrc][1][block_ind, :]
175+
)
176+
177+
delayed_field_block = delayed(parallel_field_deriv, pure=True)(
178+
ATinv_df_duT_v, Asubdiag, self.field_derivs[tInd][isrc][0][:, block_ind]
179+
)
185180

186-
for col_block in ind_col:
187181
source_blocks.append(
188182
dask.array.from_delayed(
189-
delayed(parallel_block_compute, pure=True)(
190-
self, f, src, ATinv_df_duT_v, d_count,
191-
col_block, tInd, solution_type, Jmatrix, Asubdiag,
192-
self.field_derivs[tInd][isrc]
193-
),
194-
shape=self.field_derivs[tInd+1][isrc].shape,
183+
delayed_field_block,
184+
shape=(Asubdiag.shape[0], len(block_ind)),
195185
dtype=np.float32
196186
)
197187
)
188+
189+
row_blocks.append(dask.array.from_delayed(
190+
delayed_J_block,
191+
shape=(len(block_ind), m_size),
192+
dtype=np.float32
193+
))
198194
# print(f"Appending block {isrc} in {time() - tc} seconds")
199-
d_count += len(col_block)
195+
d_count += len(block_ind)
196+
197+
field_deriv_blocks.append(dask.array.hstack(source_blocks))
198+
199+
if self.store_sensitivities == "disk":
200+
Jmatrix.set_orthogonal_selection(
201+
(np.arange(self.survey.nD), slice(None)),
202+
Jmatrix + dask.array.vstack(row_blocks).astype(np.float32)
203+
)
204+
else:
205+
dask.compute(row_blocks)
200206

201-
row_blocks.append(dask.array.hstack(source_blocks))
202-
# print(f"Done in {time() - tc_loop} seconds")
203-
# tc = time()
204-
# print(f"Compute field derivs for {tInd}")
205207
del field_derivs_t
206-
field_derivs_t = {isrc: elem for isrc, elem in enumerate(dask.compute(row_blocks)[0])}
207-
# print(f"Done in {time() - tc} seconds")
208+
field_derivs_t = {isrc: elem for isrc, elem in enumerate(dask.compute(field_deriv_blocks)[0])}
208209

209210
for A in Ainv.values():
210211
A.clean()
211212

212213
if self.store_sensitivities == "disk":
213214
del Jmatrix
214215
return array.from_zarr(self.sensitivity_path + f"J.zarr")
215-
else:
216-
return Jmatrix.compute()
216+
217+
return Jmatrix
217218

218219
Sim.compute_J = compute_J
219220

220221

221-
def block_deriv(simulation, src, tInd, f, block_size, d_count):
222-
src_field_derivs = None
222+
def block_deriv(simulation, src, tInd, f, block_size):
223+
src_field_derivs = []
224+
j_initial = []
223225
for rx in src.receiver_list:
224226

225227
v = sp.eye(rx.nD, dtype=float)
@@ -235,37 +237,33 @@ def block_deriv(simulation, src, tInd, f, block_size, d_count):
235237
PT_v[tInd * block_size:(tInd + 1) * block_size, :],
236238
adjoint=True,
237239
)
238-
239-
if not isinstance(cur[1], Zero):
240-
simulation.J_initializer[d_count:d_count + rx.nD, :] += cur[1].T
241-
242-
if src_field_derivs is None:
243-
src_field_derivs = cur[0]
244-
else:
245-
src_field_derivs += cur[0]
240+
src_field_derivs.append(cur[0])
241+
j_initial.append(cur[1].T)
246242

247243
# n_blocks = int(np.ceil(np.prod(src_field_derivs.shape) * 8. * 1e-6 / 128.))
248244
# ind_col = np.array_split(np.arange(src_field_derivs.shape[1]), col_blocks)
249245
# return [src_field_derivs[:, ind] for ind in ind_col]
250-
return src_field_derivs
246+
return sp.hstack(src_field_derivs), sp.vstack(j_initial)
251247

252-
def parallel_block_compute(simulation, f, src, ATinv_df_duT_v, d_count, col_block, tInd, solution_type, Jmatrix, Asubdiag, field_derivs):
253-
field_derivs_t = np.asarray(
254-
field_derivs[:, col_block]
255-
- Asubdiag.T * ATinv_df_duT_v[:, col_block]
256-
)
257248

249+
def parallel_field_deriv(ATinv_df_duT_v, Asubdiag, field_derivs):
    """Return ``field_derivs - Asubdiag.T * ATinv_df_duT_v``.

    Parameters
    ----------
    ATinv_df_duT_v :
        Right-hand operand of the product with ``Asubdiag.T``.
    Asubdiag :
        Object exposing ``.T``; its transpose is the left-hand operand of
        the product.
    field_derivs :
        Term the product is subtracted from.

    Returns
    -------
    The difference, in whatever array type the operands produce.

    Notes
    -----
    ``*`` is kept on purpose rather than replaced by ``@``: whether it is a
    matrix product or an elementwise product depends on the operand types.
    NOTE(review): presumably ``Asubdiag`` is a scipy sparse matrix, which
    would make this a matrix product -- confirm against the caller.
    """
    return field_derivs - Asubdiag.T * ATinv_df_duT_v
251+
252+
253+
def parallel_block_compute(simulation, f, src, ATinv_df_duT_v, tInd, solution_type, d_count, Jmatrix, j_initial):
    """Compute one block of sensitivity rows for a source at time index ``tInd``.

    Three adjoint derivative products are requested from ``simulation``
    (``getAsubdiagDeriv``, ``getRHSDeriv`` and ``getAdiagDeriv``), combined as
    ``-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v``, transposed, and offset by
    ``j_initial``.

    Parameters
    ----------
    simulation :
        Object providing ``getAsubdiagDeriv``, ``getRHSDeriv``,
        ``getAdiagDeriv`` and the ``store_sensitivities`` attribute.
    f :
        Fields container indexed as ``f[src, solution_type, time_index]``.
    src :
        Source whose fields and right-hand-side derivative are used.
    ATinv_df_duT_v :
        Vector/matrix passed as ``v`` to each adjoint derivative call.
    tInd : int
        Time index; fields at ``tInd`` and ``tInd + 1`` are both read.
    solution_type :
        Key selecting the solved-for field in ``f``.
    d_count : int
        First row of ``Jmatrix`` this block is accumulated into
        (in-memory mode only).
    Jmatrix :
        Array updated in place in in-memory mode; unused in ``"disk"`` mode.
    j_initial :
        Term added to the transposed derivative combination.

    Returns
    -------
    The computed block when ``simulation.store_sensitivities == "disk"``;
    otherwise ``None``, after adding the block into
    ``Jmatrix[d_count:d_count + n, :]`` where ``n = dAT_dm_v.shape[1]``.
    """
    dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
        tInd, f[src, solution_type, tInd], ATinv_df_duT_v, adjoint=True
    )

    dRHST_dm_v = simulation.getRHSDeriv(
        tInd + 1, src, ATinv_df_duT_v, adjoint=True
    )
    un_src = f[src, solution_type, tInd + 1]
    dAT_dm_v = simulation.getAdiagDeriv(
        tInd, un_src, ATinv_df_duT_v, adjoint=True
    )

    # Computed once and shared by both storage modes.
    j_block = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + j_initial

    if simulation.store_sensitivities == "disk":
        # The caller assembles and writes the returned blocks itself.
        return j_block

    # In-memory mode: accumulate in place; nothing is returned so the block
    # is not kept alive a second time as a task result.
    Jmatrix[d_count:d_count + dAT_dm_v.shape[1], :] += j_block
    return None

0 commit comments

Comments
 (0)