@@ -103,7 +103,7 @@ def __init__(
103103 "Disk storage of sensitivities is not compatible with distributed processing."
104104 )
105105
106- self ._workers : list [tuple [str ]] | None = self .validate_workers (workers )
106+ self ._workers : list [tuple [str ]] = self .validate_workers (workers )
107107
108108 @property
109109 def out_group (self ) -> SimPEGGroup :
@@ -799,33 +799,40 @@ def get_regularization(self):
799799
800800 return objective_function .ComboObjectiveFunction (objfcts = reg_funcs )
801801
802- def get_tiles (self ):
802+ def get_tiles (self ) -> dict [str , list [np .ndarray ]]:
803+ """
804+ Parse the data locations into tiles for distributed processing.
805+
806+ Adapts differently to the inversion type (1D, 2D or 3D).
807+
808+ :return: Dictionary with channels as keys and list of tiles as values.
809+ """
803810 n_data = self .inversion_data .mask .sum ()
804811 indices = np .arange (n_data )
805812
806- if "2d" in self .params .inversion_type :
807- return [indices ]
808-
813+ # Split tiles based on inversion type
809814 if "1d" in self .params .inversion_type :
810815 # Heuristic to avoid too many chunks
811816 n_chunks = n_data // self .params .compute .max_chunk_size
812817
813- if self .params . compute . n_workers :
814- n_chunks /= self .params . compute . n_workers
815- n_chunks = int (n_chunks ) * self .params . compute . n_workers
818+ if len ( self .workers ) > 0 :
819+ n_chunks /= len ( self .workers )
820+ n_chunks = int (n_chunks ) * len ( self .workers )
816821
817- n_chunks = np .max ([n_chunks , 1 ])
822+ n_chunks = np .max ([n_chunks , 1 , len (self .workers )])
823+ tiles = [[tile ] for tile in np .array_split (indices , n_chunks )]
818824
819- return np .array_split (indices , n_chunks )
825+ elif "2d" in self .params .inversion_type :
826+ tiles = [[indices ]]
820827
821- tiles = tile_locations (
822- self . inversion_data . locations ,
823- self .params . compute . tile_spatial ,
824- labels = self .inversion_data . parts ,
825- sorting = self .simulation . survey . sorting ,
826- )
827-
828- tiles = self .split_list (tiles )
828+ else :
829+ tiles = tile_locations (
830+ self .inversion_data . locations ,
831+ self .params . compute . tile_spatial ,
832+ labels = self .inversion_data . parts ,
833+ sorting = self . simulation . survey . sorting ,
834+ )
835+ tiles = self .split_list (tiles )
829836
830837 # Base slice over frequencies
831838 if self .params .inversion_type in ["magnetotellurics" , "tipper" , "fdem" ]:
0 commit comments