Skip to content

Commit 339a18d

Browse files
committed
Revert "Store tiles on driver"
This reverts commit 834bb6a. Revert "Re-assign computed list of dask tasks" This reverts commit 07fc674.
1 parent 834bb6a commit 339a18d

3 files changed

Lines changed: 31 additions & 55 deletions

File tree

simpeg_drivers/components/factories/misfit_factory.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import TYPE_CHECKING
1717

1818
import numpy as np
19-
from dask.distributed import Client, wait
19+
from dask.distributed import wait
2020
from simpeg import objective_function
2121
from simpeg.dask import objective_function as dask_objective_function
2222
from simpeg.objective_function import ComboObjectiveFunction
@@ -30,38 +30,32 @@
3030

3131

3232
class MisfitFactory(SimPEGFactory):
33-
"""
34-
Build SimPEG global misfit function.
35-
36-
:param params: Options object containing SimPEG object parameters.
37-
:param simulation: SimPEG simulation object.
38-
:param tiles: Dictionary of nested lists with arrays of indices for the tiles.
39-
:param client: Dask client or boolean to indicate whether to use dask.
40-
:param workers: List of worker addresses to use for dask computations.
41-
"""
42-
43-
def __init__(
44-
self,
45-
params,
46-
simulation,
47-
tiles: dict[list[np.ndarray]],
48-
client: Client | bool,
49-
workers: list[tuple[str]],
50-
):
33+
"""Build SimPEG global misfit function."""
34+
35+
def __init__(self, params, client, simulation, workers):
36+
"""
37+
:param params: Options object containing SimPEG object parameters.
38+
"""
5139
super().__init__(params)
5240

5341
self.simpeg_object = self.concrete_object()
42+
self.factory_type = self.params.inversion_type
5443
self.simulation = simulation
55-
self.tiles = tiles
5644
self.client = client
5745
self.workers = workers
5846

5947
def concrete_object(self):
6048
return objective_function.ComboObjectiveFunction
6149

6250
def assemble_arguments( # pylint: disable=arguments-differ
63-
self,
51+
self, tiles
6452
):
53+
# Base slice over frequencies
54+
if self.factory_type in ["magnetotellurics", "tipper", "fdem"]:
55+
channels = self.simulation.survey.frequencies
56+
else:
57+
channels = [None]
58+
6559
use_futures = self.client
6660

6761
# Pickle the simulation to the temporary file
@@ -72,9 +66,8 @@ def assemble_arguments( # pylint: disable=arguments-differ
7266

7367
misfits = []
7468
tile_count = 0
75-
for channel, tiles in self.tiles.items():
69+
for channel in channels:
7670
for local_indices in tiles:
77-
# Split again but use the same mesh extent based on tile vertices
7871
for sub_ind in local_indices:
7972
if len(sub_ind) == 0:
8073
continue
@@ -124,10 +117,10 @@ def assemble_arguments( # pylint: disable=arguments-differ
124117
def assemble_keyword_arguments(self, **_):
125118
"""Implementation of abstract method from SimPEGFactory."""
126119

127-
def build(self, **_):
120+
def build(self, tiles, **_):
128121
"""To be over-ridden in factory implementations."""
129122

130-
misfits = self.assemble_arguments()
123+
misfits = self.assemble_arguments(tiles)
131124

132125
if self.client:
133126
return dask_objective_function.DistributedComboMisfits(

simpeg_drivers/driver.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
"Disk storage of sensitivities is not compatible with distributed processing."
104104
)
105105

106-
self._workers: list[tuple[str]] = self.validate_workers(workers)
106+
self._workers: list[tuple[str]] | None = self.validate_workers(workers)
107107

108108
@property
109109
def out_group(self) -> SimPEGGroup:
@@ -142,7 +142,7 @@ def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup:
142142
return out_group
143143

144144
@property
145-
def client(self) -> Client | bool:
145+
def client(self) -> Client | bool | None:
146146
"""
147147
Dask client or False if not using Dask.distributed.
148148
"""
@@ -279,7 +279,6 @@ def __init__(
279279
self._ordering: list[np.ndarray] | None = None
280280
self._mappings: list[maps.IdentityMap] | None = None
281281
self._window = None
282-
self.tiles: dict[list[np.ndarray]]
283282

284283
def split_list(self, tiles: list[np.ndarray]) -> list[np.ndarray]:
285284
"""
@@ -310,9 +309,7 @@ def split_list(self, tiles: list[np.ndarray]) -> list[np.ndarray]:
310309

311310
flat_tile_list = []
312311
for tile, split in zip(tiles, split_list):
313-
flat_tile_list.append(
314-
sub for sub in np.array_split(tile, split) if len(sub) > 0
315-
)
312+
flat_tile_list.append(np.array_split(tile, split))
316313
return flat_tile_list
317314

318315
@property
@@ -321,19 +318,16 @@ def data_misfit(self):
321318
if getattr(self, "_data_misfit", None) is None:
322319
with fetch_active_workspace(self.workspace, mode="r+"):
323320
# Tile locations
324-
if self.logger and self.params.compute.tile_spatial > 1:
325-
self.logger.write(
326-
f"Setting up {self.params.compute.tile_spatial} tiles . . .\n"
327-
)
321+
tiles = self.get_tiles()
322+
323+
if self.logger:
324+
self.logger.write(f"Setting up {len(tiles)} tile(s) . . .\n")
328325

329-
self.tiles = self.get_tiles()
330326
self._data_misfit = MisfitFactory(
331-
self.params,
332-
self.simulation,
333-
self.tiles,
334-
client=self.client,
335-
workers=self.workers,
336-
).build()
327+
self.params, self.client, self.simulation, self.workers
328+
).build(
329+
self.split_list(tiles),
330+
)
337331

338332
return self._data_misfit
339333

@@ -782,24 +776,13 @@ def get_tiles(self):
782776

783777
return np.array_split(indices, n_chunks)
784778

785-
tiles = tile_locations(
779+
return tile_locations(
786780
self.inversion_data.locations,
787781
self.params.compute.tile_spatial,
788782
labels=self.inversion_data.parts,
789783
sorting=self.simulation.survey.sorting,
790784
)
791785

792-
self.split_list(tiles)
793-
794-
# Base slice over frequencies
795-
if self.params.inversion_type in ["magnetotellurics", "tipper", "fdem"]:
796-
channels = self.simulation.survey.frequencies
797-
else:
798-
channels = [None]
799-
800-
# Duplicate tiles for each channel
801-
return {channel: tiles for channel in channels}
802-
803786
@classmethod
804787
def start(cls, filepath: str | Path | InputFile, **kwargs) -> Self:
805788
"""

simpeg_drivers/plate_simulation/match/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def run(self):
223223
# Display progress bar
224224
if isinstance(tasks[0], Future):
225225
progress(tasks)
226-
tasks = self.client.gather(tasks)
226+
self.client.gather(tasks)
227227

228228
scores = np.hstack(tasks)
229229
ranked = np.argsort(scores)[::-1]

0 commit comments

Comments
 (0)