Skip to content

Commit 489775b

Browse files
authored
Merge pull request #123 from MiraGeoscience/GEOPY-2526
GEOPY-2526: Improve parallel creation of 1D simulations
2 parents 7564c09 + 284bdad commit 489775b

3 files changed

Lines changed: 24 additions & 19 deletions

File tree

simpeg/dask/electromagnetics/frequency_domain/simulation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def compute_J(self, m, f=None):
282282
for block_derivs_chunks, addresses_chunks in zip(
283283
blocks_receiver_derivs, blocks, strict=True
284284
):
285-
Jmatrix = parallel_block_compute(
285+
parallel_block_compute(
286286
simulation,
287287
m,
288288
Jmatrix,
@@ -301,7 +301,10 @@ def compute_J(self, m, f=None):
301301
gc.collect()
302302
if self.store_sensitivities == "disk":
303303
del Jmatrix
304-
Jmatrix = array.from_zarr(self.sensitivity_path)
304+
return array.from_zarr(self.sensitivity_path)
305+
306+
if client:
307+
return client.gather(Jmatrix)
305308

306309
return Jmatrix
307310

@@ -367,11 +370,9 @@ def parallel_block_compute(
367370
count += n_cols
368371

369372
if client:
370-
client.gather(block_delayed)
373+
return client.gather(block_delayed)
371374
else:
372-
compute(block_delayed)
373-
374-
return Jmatrix
375+
return compute(block_delayed)
375376

376377

377378
Sim.compute_J = compute_J

simpeg/dask/electromagnetics/static/resistivity/simulation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ def getSourceTerm(self):
181181
)
182182
blocks = []
183183
for ind in indices:
184+
185+
if len(ind) == 0:
186+
continue
187+
184188
blocks.append(
185189
client.submit(source_eval, sim, future_list, ind, workers=worker)
186190
)

simpeg/directives/_save_geoh5.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,20 @@ def __init__(
5252
)
5353

5454
def initialize(self):
55-
if self.open_geoh5:
56-
self._geoh5.open(mode="r+")
55+
if self.open_geoh5 and not getattr(self._workspace, "_geoh5", None):
56+
self._workspace.open(mode="r+")
5757

5858
self.write(0)
5959

6060
if self.close_geoh5:
61-
self._geoh5.close()
61+
self._workspace.close()
6262

6363
def endIter(self):
64-
if self.open_geoh5:
65-
self._geoh5.open(mode="r+")
64+
if self.open_geoh5 and not getattr(self._workspace, "_geoh5", None):
65+
self._workspace.open(mode="r+")
6666
self.write(self.opt.iter)
6767
if self.close_geoh5:
68-
self._geoh5.close()
68+
self._workspace.close()
6969

7070
def get_names(
7171
self, component: str, channel: str, iteration: int
@@ -127,7 +127,7 @@ def h5_object(self, entity: ObjectBase):
127127
)
128128

129129
self._h5_object = entity.uid
130-
self._geoh5 = entity.workspace
130+
self._workspace = entity.workspace
131131

132132
if getattr(entity, "n_cells", None) is not None:
133133
self.association = "CELL"
@@ -263,7 +263,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq
263263
prop = self.apply_transformations(prop)
264264

265265
# Save results
266-
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
266+
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
267267
h5_object = w_s.get_entity(self.h5_object)[0]
268268
for cc, component in enumerate(self.components):
269269
if component not in self.data_type:
@@ -386,7 +386,7 @@ def joint_index(self, value: list[int] | None):
386386
class SaveLogFilesGeoH5(BaseSaveGeoH5):
387387

388388
def write(self, iteration: int, **_):
389-
dirpath = Path(self._geoh5.h5file).parent
389+
dirpath = Path(self._workspace.h5file).parent
390390
filepath = dirpath / "SimPEG.out"
391391

392392
if iteration == 0:
@@ -412,9 +412,9 @@ def save_log(self):
412412
"""
413413
Save iteration metrics to comments.
414414
"""
415-
dirpath = Path(self._geoh5.h5file).parent
415+
dirpath = Path(self._workspace.h5file).parent
416416

417-
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
417+
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
418418
h5_object = w_s.get_entity(self.h5_object)[0]
419419

420420
for file in ["SimPEG.out", "SimPEG.log", "ChiFactors.log"]:
@@ -452,7 +452,7 @@ def write(self, iteration: int, **_):
452452
"""
453453
Save the model to the geoh5 file
454454
"""
455-
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
455+
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
456456
h5_object = w_s.get_entity(self.h5_object)[0]
457457

458458
for component in self.components:
@@ -560,7 +560,7 @@ def write(self, iteration: int, values: list[np.ndarray] | None = None):
560560
petro_model = self.get_values(values)
561561
petro_model = self.apply_transformations(petro_model).flatten()
562562
channel_name, _ = self.get_names("petrophysics", "", iteration)
563-
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
563+
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
564564
h5_object = w_s.get_entity(self.h5_object)[0]
565565
data = h5_object.add_data(
566566
{

0 commit comments

Comments
 (0)