Skip to content

Commit 8ae9c50

Browse files
committed
MOve different tiling strategies to sub classes
1 parent f4d1e3d commit 8ae9c50

3 files changed

Lines changed: 111 additions & 61 deletions

File tree

simpeg_drivers/driver.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -806,35 +806,13 @@ def get_tiles(self) -> dict[str, list[np.ndarray]]:
806806
807807
:return: Dictionary with channels as keys and list of tiles as values.
808808
"""
809-
n_data = self.inversion_data.mask.sum()
810-
indices = np.arange(n_data)
811-
812-
# Split tiles based on inversion type
813-
if "1d" in self.params.inversion_type:
814-
# Heuristic to avoid too many chunks
815-
n_chunks = n_data // self.params.compute.max_chunk_size
816-
817-
if self.workers:
818-
n_chunks /= len(self.workers)
819-
n_chunks = int(n_chunks) * len(self.workers)
820-
821-
n_chunks = np.max([n_chunks, 1, len(self.workers)])
822-
tiles = [[tile] for tile in np.array_split(indices, n_chunks)]
823-
# Split per line for 2D inversions
824-
elif "2d" in self.params.inversion_type:
825-
tiles = [
826-
[np.where(self.params.line_selection.line_object.values == line_id)[0]]
827-
for line_id in np.unique(self.params.line_selection.line_object.values)
828-
]
829-
# Kmeans split with subsequent splitting to optimize load
830-
else:
831-
tiles = tile_locations(
832-
self.inversion_data.locations,
833-
self.params.compute.tile_spatial,
834-
labels=self.inversion_data.parts,
835-
sorting=self.simulation.survey.sorting,
836-
)
837-
tiles = self.split_list(tiles)
809+
tiles = tile_locations(
810+
self.inversion_data.locations,
811+
self.params.compute.tile_spatial,
812+
labels=self.inversion_data.parts,
813+
sorting=self.simulation.survey.sorting,
814+
)
815+
tiles = self.split_list(tiles)
838816

839817
# Base slice over frequencies
840818
if self.params.inversion_type in [

simpeg_drivers/electricals/base_2d.py

Lines changed: 90 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from logging import getLogger
1515

16+
import numpy as np
1617
from geoh5py.objects import DrapeModel, Octree, PotentialElectrode
1718
from geoh5py.ui_json.ui_json import fetch_active_workspace
1819
from pydantic import field_validator, model_validator
@@ -24,42 +25,15 @@
2425
DrapeModelOptions,
2526
LineSelectionOptions,
2627
)
27-
from simpeg_drivers.utils.surveys import create_mesh_by_line_id
28+
from simpeg_drivers.utils.surveys import (
29+
create_mesh_by_line_id,
30+
get_parts_from_electrodes,
31+
)
2832

2933

3034
logger = getLogger(__name__)
3135

3236

33-
class Base2DDriver(InversionDriver):
34-
"""
35-
Base class for 2D DC and IP forward and inversion drivers.
36-
37-
Survey lines are inverted independently and internally stacked as a single
38-
long survey. The inversion mesh is created as a drape mesh over the survey lines.
39-
"""
40-
41-
@property
42-
def inversion_mesh(self) -> InversionMesh:
43-
"""Inversion mesh"""
44-
if getattr(self, "_inversion_mesh", None) is None:
45-
with fetch_active_workspace(self.workspace, mode="r+"):
46-
entity = None
47-
if self.params.mesh is None:
48-
entity = create_mesh_by_line_id(
49-
self.workspace,
50-
self.params.line_selection.line_object,
51-
self.params.drape_model,
52-
parent=self.out_group,
53-
)
54-
self.params.mesh = entity
55-
56-
self._inversion_mesh = InversionMesh(
57-
self.workspace, self.params, entity=entity
58-
)
59-
60-
return self._inversion_mesh
61-
62-
6337
class Base2DOptions(CoreOptions):
6438
"""
6539
Base options for the Direct Current 2D forward and inverse driver.
@@ -72,10 +46,13 @@ class Base2DOptions(CoreOptions):
7246
"""
7347

7448
data_object: PotentialElectrode
75-
line_selection: LineSelectionOptions = LineSelectionOptions()
49+
line_selection: LineSelectionOptions | None = None
7650
mesh: DrapeModel | Octree | None = None
7751
drape_model: DrapeModelOptions = DrapeModelOptions()
7852

53+
_line_parts: np.ndarray | None = None
54+
_selected_parts: list[int] | None = None
55+
7956
@field_validator("mesh", mode="before")
8057
@classmethod
8158
def mesh_cannot_be_octree(cls, value: Octree | DrapeModel):
@@ -105,3 +82,84 @@ def deprecated_pseudo(cls, data: dict):
10582
data["line_selection"] = line_selection
10683

10784
return data
85+
86+
@property
87+
def line_parts(self) -> np.ndarray:
88+
"""
89+
Generate monotonic line parts from line identifier or inferred from graph of potentials.
90+
"""
91+
if self._line_parts is None:
92+
if (
93+
self.line_selection is not None
94+
and self.line_selection.property is not None
95+
):
96+
_, self._line_parts = np.unique(
97+
self.line_selection.property.values, return_inverse=True
98+
)
99+
else:
100+
self._line_parts = get_parts_from_electrodes(self.data_object)
101+
102+
return self._line_parts
103+
104+
@property
105+
def selected_parts(self) -> list[int]:
106+
"""
107+
Translate line section ids to monotonic parts.
108+
"""
109+
if self._selected_parts is None:
110+
parts = []
111+
if (
112+
self.line_selection is not None
113+
and self.line_selection.property is not None
114+
):
115+
for count, val in enumerate(
116+
np.unique(self.line_selection.property.values)
117+
):
118+
if val in self.line_selection.value:
119+
parts.append(count)
120+
else:
121+
parts = np.arange(len(np.unique(self.line_parts)))
122+
123+
self._selected_parts = parts
124+
125+
return self._selected_parts
126+
127+
128+
class Base2DDriver(InversionDriver):
129+
"""
130+
Base class for 2D DC and IP forward and inversion drivers.
131+
132+
Survey lines are inverted independently and internally stacked as a single
133+
long survey. The inversion mesh is created as a drape mesh over the survey lines.
134+
"""
135+
136+
@property
137+
def inversion_mesh(self) -> InversionMesh:
138+
"""Inversion mesh"""
139+
if getattr(self, "_inversion_mesh", None) is None:
140+
with fetch_active_workspace(self.workspace, mode="r+"):
141+
entity = None
142+
if self.params.mesh is None:
143+
entity = create_mesh_by_line_id(
144+
self.workspace,
145+
self.params.line_ids,
146+
self.params.drape_model,
147+
parent=self.out_group,
148+
)
149+
self.params.mesh = entity
150+
151+
self._inversion_mesh = InversionMesh(
152+
self.workspace, self.params, entity=entity
153+
)
154+
155+
return self._inversion_mesh
156+
157+
def get_tiles(self) -> dict[str, list[np.ndarray]]:
158+
"""
159+
Generate tiles from survey parts.
160+
"""
161+
tiles = [
162+
[np.where(self.params.line_parts == part)[0]]
163+
for part in self.params.selected_parts
164+
]
165+
return {None: tiles}

simpeg_drivers/electromagnetics/base_1d_driver.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,20 @@ def get_1d_mesh(self) -> TensorMesh:
8686
)
8787
return layers_mesh
8888

89+
def get_tiles(self) -> dict[str, list[np.ndarray]]:
90+
n_data = self.inversion_data.mask.sum()
91+
indices = np.arange(n_data)
92+
93+
# Heuristic to avoid too many chunks
94+
n_chunks = n_data // self.params.compute.max_chunk_size
95+
96+
if self.workers:
97+
n_chunks /= len(self.workers)
98+
n_chunks = int(n_chunks) * len(self.workers)
99+
100+
n_chunks = np.max([n_chunks, 1, len(self.workers)])
101+
return {None: [[tile] for tile in np.array_split(indices, n_chunks)]}
102+
89103
@property
90104
def simulation(self):
91105
"""

0 commit comments

Comments
 (0)