@@ -103,7 +103,7 @@ def __init__(
103103 "Disk storage of sensitivities is not compatible with distributed processing."
104104 )
105105
106- self ._workers : list [tuple [str ]] = self .validate_workers (workers )
106+ self ._workers : list [tuple [str ]] | None = self .validate_workers (workers )
107107
108108 @property
109109 def out_group (self ) -> SimPEGGroup :
@@ -142,7 +142,7 @@ def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup:
142142 return out_group
143143
144144 @property
145- def client (self ) -> Client | bool :
145+ def client (self ) -> Client | bool | None :
146146 """
147147 Dask client or False if not using Dask.distributed.
148148 """
@@ -279,7 +279,6 @@ def __init__(
279279 self ._ordering : list [np .ndarray ] | None = None
280280 self ._mappings : list [maps .IdentityMap ] | None = None
281281 self ._window = None
282- self .tiles : dict [list [np .ndarray ]]
283282
284283 def split_list (self , tiles : list [np .ndarray ]) -> list [np .ndarray ]:
285284 """
@@ -310,9 +309,7 @@ def split_list(self, tiles: list[np.ndarray]) -> list[np.ndarray]:
310309
311310 flat_tile_list = []
312311 for tile , split in zip (tiles , split_list ):
313- flat_tile_list .append (
314- sub for sub in np .array_split (tile , split ) if len (sub ) > 0
315- )
312+ flat_tile_list .append (np .array_split (tile , split ))
316313 return flat_tile_list
317314
318315 @property
@@ -321,19 +318,16 @@ def data_misfit(self):
321318 if getattr (self , "_data_misfit" , None ) is None :
322319 with fetch_active_workspace (self .workspace , mode = "r+" ):
323320 # Tile locations
324- if self . logger and self .params . compute . tile_spatial > 1 :
325- self . logger . write (
326- f"Setting up { self .params . compute . tile_spatial } tiles . . . \n "
327- )
321+ tiles = self .get_tiles ()
322+
323+ if self .logger :
324+ self . logger . write ( f"Setting up { len ( tiles ) } tile(s) . . . \n " )
328325
329- self .tiles = self .get_tiles ()
330326 self ._data_misfit = MisfitFactory (
331- self .params ,
332- self .simulation ,
333- self .tiles ,
334- client = self .client ,
335- workers = self .workers ,
336- ).build ()
327+ self .params , self .client , self .simulation , self .workers
328+ ).build (
329+ self .split_list (tiles ),
330+ )
337331
338332 return self ._data_misfit
339333
@@ -782,24 +776,13 @@ def get_tiles(self):
782776
783777 return np .array_split (indices , n_chunks )
784778
785- tiles = tile_locations (
779+ return tile_locations (
786780 self .inversion_data .locations ,
787781 self .params .compute .tile_spatial ,
788782 labels = self .inversion_data .parts ,
789783 sorting = self .simulation .survey .sorting ,
790784 )
791785
792- self .split_list (tiles )
793-
794- # Base slice over frequencies
795- if self .params .inversion_type in ["magnetotellurics" , "tipper" , "fdem" ]:
796- channels = self .simulation .survey .frequencies
797- else :
798- channels = [None ]
799-
800- # Duplicate tiles for each channel
801- return {channel : tiles for channel in channels }
802-
803786 @classmethod
804787 def start (cls , filepath : str | Path | InputFile , ** kwargs ) -> Self :
805788 """
0 commit comments