@@ -103,7 +103,7 @@ def __init__(
103103 "Disk storage of sensitivities is not compatible with distributed processing."
104104 )
105105
106- self ._workers : list [tuple [str ]] | None = self .validate_workers (workers )
106+ self ._workers : list [tuple [str ]] = self .validate_workers (workers )
107107
108108 @property
109109 def out_group (self ) -> SimPEGGroup :
@@ -142,7 +142,7 @@ def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup:
142142 return out_group
143143
144144 @property
145- def client (self ) -> Client | bool | None :
145+ def client (self ) -> Client | bool :
146146 """
147147 Dask client or False if not using Dask.distributed.
148148 """
@@ -279,6 +279,7 @@ def __init__(
279279 self ._ordering : list [np .ndarray ] | None = None
280280 self ._mappings : list [maps .IdentityMap ] | None = None
281281 self ._window = None
282+ self .tiles : dict [list [np .ndarray ]]
282283
283284 def split_list (self , tiles : list [np .ndarray ]) -> list [np .ndarray ]:
284285 """
@@ -309,7 +310,9 @@ def split_list(self, tiles: list[np.ndarray]) -> list[np.ndarray]:
309310
310311 flat_tile_list = []
311312 for tile , split in zip (tiles , split_list ):
312- flat_tile_list .append (np .array_split (tile , split ))
313+ flat_tile_list .append (
314+ sub for sub in np .array_split (tile , split ) if len (sub ) > 0
315+ )
313316 return flat_tile_list
314317
315318 @property
@@ -318,16 +321,19 @@ def data_misfit(self):
318321 if getattr (self , "_data_misfit" , None ) is None :
319322 with fetch_active_workspace (self .workspace , mode = "r+" ):
320323 # Tile locations
321- tiles = self .get_tiles ()
322-
323- if self .logger :
324- self . logger . write ( f"Setting up { len ( tiles ) } tile(s) . . . \n " )
324+ if self . logger and self .params . compute . tile_spatial > 1 :
325+ self . logger . write (
326+ f"Setting up { self .params . compute . tile_spatial } tiles . . . \n "
327+ )
325328
329+ self .tiles = self .get_tiles ()
326330 self ._data_misfit = MisfitFactory (
327- self .params , self .client , self .simulation , self .workers
328- ).build (
329- self .split_list (tiles ),
330- )
331+ self .params ,
332+ self .simulation ,
333+ self .tiles ,
334+ client = self .client ,
335+ workers = self .workers ,
336+ ).build ()
331337
332338 return self ._data_misfit
333339
@@ -776,13 +782,24 @@ def get_tiles(self):
776782
777783 return np .array_split (indices , n_chunks )
778784
779- return tile_locations (
785+ tiles = tile_locations (
780786 self .inversion_data .locations ,
781787 self .params .compute .tile_spatial ,
782788 labels = self .inversion_data .parts ,
783789 sorting = self .simulation .survey .sorting ,
784790 )
785791
792+ self .split_list (tiles )
793+
794+ # Base slice over frequencies
795+ if self .params .inversion_type in ["magnetotellurics" , "tipper" , "fdem" ]:
796+ channels = self .simulation .survey .frequencies
797+ else :
798+ channels = [None ]
799+
800+ # Duplicate tiles for each channel
801+ return {channel : tiles for channel in channels }
802+
786803 @classmethod
787804 def start (cls , filepath : str | Path | InputFile , ** kwargs ) -> Self :
788805 """
0 commit comments