Skip to content

Commit 45a77df

Browse files
committed
zarr the field derivatives if on disk
1 parent cea15d1 commit 45a77df

1 file changed

Lines changed: 45 additions & 26 deletions

File tree

SimPEG/dask/electromagnetics/time_domain/simulation.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,16 @@ def compute_J(self, f=None, Ainv=None):
125125
mode='w',
126126
shape=(self.survey.nD, m_size),
127127
chunks=(row_chunks, m_size)
128-
)# + J_initializer
128+
)
129+
partial_derivs = zarr.open(
130+
self.sensitivity_path + f"partials.zarr",
131+
mode='w',
132+
shape=(self.getAsubdiag(0).shape[0], self.survey.nD),
133+
chunks=(self.getAsubdiag(0).shape[0], row_chunks)
134+
)
129135
else:
130136
Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)
137+
partial_derivs = np.zeros((self.getAsubdiag(0).shape[0], self.survey.nD), dtype=np.float32)
131138

132139
if self.field_derivs is None:
133140
block_size = len(f[self.survey.source_list[0], solution_type, 0])
@@ -145,67 +152,79 @@ def compute_J(self, f=None, Ainv=None):
145152
self.field_derivs = dask.compute(field_derivs)[0]
146153

147154
f = dask.delayed(f)
148-
field_derivs_t = {}
155+
field_derivatives = {}
149156

150157
for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))):
151158

152159
AdiagTinv = Ainv[dt]
153160
Asubdiag = self.getAsubdiag(tInd)
154161
d_count = 0
155162
field_deriv_blocks = []
156-
row_blocks = []
163+
j_row_blocks = []
157164

158165
for isrc, src in enumerate(self.survey.source_list):
159-
source_blocks = []
166+
field_blocks = []
160167
n_data = self.field_derivs[tInd+1][isrc][0].shape[1]
161168
n_blocks = int(np.ceil((m_size * n_data) * 8. * 1e-6 / 128.))
162169
sub_blocks = np.array_split(np.arange(n_data), n_blocks)
163170

164171
for block_ind in sub_blocks:
165-
if isrc not in field_derivs_t:
172+
if isrc not in field_derivatives:
166173
ATinv_df_duT_v = (
167174
AdiagTinv * self.field_derivs[tInd + 1][isrc][0][:, block_ind].toarray()
168175
)
169176
else:
170-
ATinv_df_duT_v = AdiagTinv * np.asarray(field_derivs_t[isrc][:, block_ind])
171-
172-
delayed_J_block = delayed(parallel_block_compute, pure=True)(
173-
self, f, src, ATinv_df_duT_v,
174-
tInd, solution_type, d_count, Jmatrix, self.field_derivs[tInd+1][isrc][1][block_ind, :]
175-
)
177+
ATinv_df_duT_v = AdiagTinv * np.asarray(field_derivatives[isrc][:, block_ind])
176178

177-
delayed_field_block = delayed(parallel_field_deriv, pure=True)(
178-
ATinv_df_duT_v, Asubdiag, self.field_derivs[tInd][isrc][0][:, block_ind]
179-
)
179+
if self.store_sensitivities == "disk":
180+
partial_derivs.set_orthogonal_selection(
181+
(slice(None), slice(d_count, d_count + len(block_ind))),
182+
ATinv_df_duT_v
183+
)
184+
else:
185+
partial_derivs[:, d_count: d_count + len(block_ind)] = ATinv_df_duT_v
180186

181-
source_blocks.append(
187+
field_blocks.append(
182188
dask.array.from_delayed(
183-
delayed_field_block,
189+
delayed(parallel_field_deriv, pure=True)(
190+
partial_derivs[:, d_count: d_count + len(block_ind)], Asubdiag,
191+
self.field_derivs[tInd][isrc][0][:, block_ind]
192+
),
184193
shape=(Asubdiag.shape[0], len(block_ind)),
185-
dtype=np.float32
194+
dtype=np.float64
186195
)
187196
)
188-
189-
row_blocks.append(dask.array.from_delayed(
190-
delayed_J_block,
197+
j_row_blocks.append(dask.array.from_delayed(
198+
delayed(parallel_block_compute, pure=True)(
199+
self, f, src, partial_derivs[:, d_count: d_count + len(block_ind)],
200+
tInd, solution_type, d_count, Jmatrix, self.field_derivs[tInd + 1][isrc][1][block_ind, :]
201+
),
191202
shape=(len(block_ind), m_size),
192203
dtype=np.float32
193204
))
194-
# print(f"Appending block {isrc} in {time() - tc} seconds")
195205
d_count += len(block_ind)
196206

197-
field_deriv_blocks.append(dask.array.hstack(source_blocks))
207+
field_deriv_blocks.append(dask.array.hstack(field_blocks))
208+
209+
del field_derivatives
198210

199211
if self.store_sensitivities == "disk":
200212
Jmatrix.set_orthogonal_selection(
201213
(np.arange(self.survey.nD), slice(None)),
202-
Jmatrix + dask.array.vstack(row_blocks).astype(np.float32)
214+
Jmatrix + dask.array.vstack(j_row_blocks).astype(np.float32)
203215
)
216+
field_derivatives = [
217+
dask.array.to_zarr(
218+
field_deriv_blocks[i], self.sensitivity_path + f"field_derivs_{i}.zarr",
219+
overwrite=True,
220+
return_stored = True,
221+
) for i in range(len(field_deriv_blocks))
222+
]
204223
else:
205-
dask.compute(row_blocks)
224+
dask.compute(j_row_blocks)
225+
field_derivatives = dask.compute(field_deriv_blocks)[0]
206226

207-
del field_derivs_t
208-
field_derivs_t = {isrc: elem for isrc, elem in enumerate(dask.compute(field_deriv_blocks)[0])}
227+
field_derivatives = {isrc: elem for isrc, elem in enumerate(field_derivatives)}
209228

210229
for A in Ainv.values():
211230
A.clean()

0 commit comments

Comments
 (0)