Skip to content

Commit be320c8

Browse files
committed
Fix get_tiles for cases 1D and 2D
1 parent 0013f23 commit be320c8

2 files changed

Lines changed: 26 additions & 19 deletions

File tree

simpeg_drivers/driver.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
"Disk storage of sensitivities is not compatible with distributed processing."
104104
)
105105

106-
self._workers: list[tuple[str]] | None = self.validate_workers(workers)
106+
self._workers: list[tuple[str]] = self.validate_workers(workers)
107107

108108
@property
109109
def out_group(self) -> SimPEGGroup:
@@ -799,33 +799,40 @@ def get_regularization(self):
799799

800800
return objective_function.ComboObjectiveFunction(objfcts=reg_funcs)
801801

802-
def get_tiles(self):
802+
def get_tiles(self) -> dict[str, list[np.ndarray]]:
803+
"""
804+
Parse the data locations into tiles for distributed processing.
805+
806+
Adapts differently to the inversion type (1D, 2D or 3D).
807+
808+
:return: Dictionary with channels as keys and list of tiles as values.
809+
"""
803810
n_data = self.inversion_data.mask.sum()
804811
indices = np.arange(n_data)
805812

806-
if "2d" in self.params.inversion_type:
807-
return [indices]
808-
813+
# Split tiles based on inversion type
809814
if "1d" in self.params.inversion_type:
810815
# Heuristic to avoid too many chunks
811816
n_chunks = n_data // self.params.compute.max_chunk_size
812817

813-
if self.params.compute.n_workers:
814-
n_chunks /= self.params.compute.n_workers
815-
n_chunks = int(n_chunks) * self.params.compute.n_workers
818+
if len(self.workers) > 0:
819+
n_chunks /= len(self.workers)
820+
n_chunks = int(n_chunks) * len(self.workers)
816821

817-
n_chunks = np.max([n_chunks, 1])
822+
n_chunks = np.max([n_chunks, 1, len(self.workers)])
823+
tiles = [[tile] for tile in np.array_split(indices, n_chunks)]
818824

819-
return np.array_split(indices, n_chunks)
825+
elif "2d" in self.params.inversion_type:
826+
tiles = [[indices]]
820827

821-
tiles = tile_locations(
822-
self.inversion_data.locations,
823-
self.params.compute.tile_spatial,
824-
labels=self.inversion_data.parts,
825-
sorting=self.simulation.survey.sorting,
826-
)
827-
828-
tiles = self.split_list(tiles)
828+
else:
829+
tiles = tile_locations(
830+
self.inversion_data.locations,
831+
self.params.compute.tile_spatial,
832+
labels=self.inversion_data.parts,
833+
sorting=self.simulation.survey.sorting,
834+
)
835+
tiles = self.split_list(tiles)
829836

830837
# Base slice over frequencies
831838
if self.params.inversion_type in ["magnetotellurics", "tipper", "fdem"]:

simpeg_drivers/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def mesh_cannot_be_rotated(cls, value: Octree):
225225

226226
@property
227227
def workpath(self):
228-
return Path(self.geoh5.h5file).parent
228+
return Path(self.geoh5.h5file).resolve().parent
229229

230230
@property
231231
def padding_cells(self) -> int:

0 commit comments

Comments
 (0)