Skip to content

Commit 834bb6a

Browse files
committed
Store tiles on driver
1 parent 07fc674 commit 834bb6a

2 files changed

Lines changed: 54 additions & 30 deletions

File tree

simpeg_drivers/components/factories/misfit_factory.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import TYPE_CHECKING
1717

1818
import numpy as np
19-
from dask.distributed import wait
19+
from dask.distributed import Client, wait
2020
from simpeg import objective_function
2121
from simpeg.dask import objective_function as dask_objective_function
2222
from simpeg.objective_function import ComboObjectiveFunction
@@ -30,32 +30,38 @@
3030

3131

3232
class MisfitFactory(SimPEGFactory):
33-
"""Build SimPEG global misfit function."""
34-
35-
def __init__(self, params, client, simulation, workers):
36-
"""
37-
:param params: Options object containing SimPEG object parameters.
38-
"""
33+
"""
34+
Build SimPEG global misfit function.
35+
36+
:param params: Options object containing SimPEG object parameters.
37+
:param simulation: SimPEG simulation object.
38+
:param tiles: Dictionary of nested lists with arrays of indices for the tiles.
39+
:param client: Dask client or boolean to indicate whether to use dask.
40+
:param workers: List of worker addresses to use for dask computations.
41+
"""
42+
43+
def __init__(
44+
self,
45+
params,
46+
simulation,
47+
tiles: dict[list[np.ndarray]],
48+
client: Client | bool,
49+
workers: list[tuple[str]],
50+
):
3951
super().__init__(params)
4052

4153
self.simpeg_object = self.concrete_object()
42-
self.factory_type = self.params.inversion_type
4354
self.simulation = simulation
55+
self.tiles = tiles
4456
self.client = client
4557
self.workers = workers
4658

4759
def concrete_object(self):
4860
return objective_function.ComboObjectiveFunction
4961

5062
def assemble_arguments( # pylint: disable=arguments-differ
51-
self, tiles
63+
self,
5264
):
53-
# Base slice over frequencies
54-
if self.factory_type in ["magnetotellurics", "tipper", "fdem"]:
55-
channels = self.simulation.survey.frequencies
56-
else:
57-
channels = [None]
58-
5965
use_futures = self.client
6066

6167
# Pickle the simulation to the temporary file
@@ -66,8 +72,9 @@ def assemble_arguments( # pylint: disable=arguments-differ
6672

6773
misfits = []
6874
tile_count = 0
69-
for channel in channels:
75+
for channel, tiles in self.tiles.items():
7076
for local_indices in tiles:
77+
# Split again but use the same mesh extent based on tile vertices
7178
for sub_ind in local_indices:
7279
if len(sub_ind) == 0:
7380
continue
@@ -117,10 +124,10 @@ def assemble_arguments( # pylint: disable=arguments-differ
117124
def assemble_keyword_arguments(self, **_):
118125
"""Implementation of abstract method from SimPEGFactory."""
119126

120-
def build(self, tiles, **_):
127+
def build(self, **_):
121128
"""To be over-ridden in factory implementations."""
122129

123-
misfits = self.assemble_arguments(tiles)
130+
misfits = self.assemble_arguments()
124131

125132
if self.client:
126133
return dask_objective_function.DistributedComboMisfits(

simpeg_drivers/driver.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
"Disk storage of sensitivities is not compatible with distributed processing."
104104
)
105105

106-
self._workers: list[tuple[str]] | None = self.validate_workers(workers)
106+
self._workers: list[tuple[str]] = self.validate_workers(workers)
107107

108108
@property
109109
def out_group(self) -> SimPEGGroup:
@@ -142,7 +142,7 @@ def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup:
142142
return out_group
143143

144144
@property
145-
def client(self) -> Client | bool | None:
145+
def client(self) -> Client | bool:
146146
"""
147147
Dask client or False if not using Dask.distributed.
148148
"""
@@ -279,6 +279,7 @@ def __init__(
279279
self._ordering: list[np.ndarray] | None = None
280280
self._mappings: list[maps.IdentityMap] | None = None
281281
self._window = None
282+
self.tiles: dict[list[np.ndarray]]
282283

283284
def split_list(self, tiles: list[np.ndarray]) -> list[np.ndarray]:
284285
"""
@@ -309,7 +310,9 @@ def split_list(self, tiles: list[np.ndarray]) -> list[np.ndarray]:
309310

310311
flat_tile_list = []
311312
for tile, split in zip(tiles, split_list):
312-
flat_tile_list.append(np.array_split(tile, split))
313+
flat_tile_list.append(
314+
sub for sub in np.array_split(tile, split) if len(sub) > 0
315+
)
313316
return flat_tile_list
314317

315318
@property
@@ -318,16 +321,19 @@ def data_misfit(self):
318321
if getattr(self, "_data_misfit", None) is None:
319322
with fetch_active_workspace(self.workspace, mode="r+"):
320323
# Tile locations
321-
tiles = self.get_tiles()
322-
323-
if self.logger:
324-
self.logger.write(f"Setting up {len(tiles)} tile(s) . . .\n")
324+
if self.logger and self.params.compute.tile_spatial > 1:
325+
self.logger.write(
326+
f"Setting up {self.params.compute.tile_spatial} tiles . . .\n"
327+
)
325328

329+
self.tiles = self.get_tiles()
326330
self._data_misfit = MisfitFactory(
327-
self.params, self.client, self.simulation, self.workers
328-
).build(
329-
self.split_list(tiles),
330-
)
331+
self.params,
332+
self.simulation,
333+
self.tiles,
334+
client=self.client,
335+
workers=self.workers,
336+
).build()
331337

332338
return self._data_misfit
333339

@@ -776,13 +782,24 @@ def get_tiles(self):
776782

777783
return np.array_split(indices, n_chunks)
778784

779-
return tile_locations(
785+
tiles = tile_locations(
780786
self.inversion_data.locations,
781787
self.params.compute.tile_spatial,
782788
labels=self.inversion_data.parts,
783789
sorting=self.simulation.survey.sorting,
784790
)
785791

792+
self.split_list(tiles)
793+
794+
# Base slice over frequencies
795+
if self.params.inversion_type in ["magnetotellurics", "tipper", "fdem"]:
796+
channels = self.simulation.survey.frequencies
797+
else:
798+
channels = [None]
799+
800+
# Duplicate tiles for each channel
801+
return {channel: tiles for channel in channels}
802+
786803
@classmethod
787804
def start(cls, filepath: str | Path | InputFile, **kwargs) -> Self:
788805
"""

0 commit comments

Comments
 (0)