77from typing import Callable
88import numpy as np
99
10- from dask .distributed import Client , Future
10+ from dask .distributed import Client , Future , get_client
1111from ..data_misfit import L2DataMisfit
1212
1313from simpeg .utils import validate_list_of_types
@@ -145,9 +145,7 @@ def _validate_type_or_future_of_type(
145145):
146146
147147 if workers is None :
148- workers = [
149- (worker .worker_address ,) for worker in client .cluster .workers .values ()
150- ]
148+ workers = [(worker ,) for worker in client .nthreads ()]
151149
152150 objects = validate_list_of_types (
153151 property_name , objects , obj_type , ensure_unique = True
@@ -262,6 +260,9 @@ def client(self):
262260
263261 @client .setter
264262 def client (self , client ):
263+ if client is None :
264+ client = get_client ()
265+
265266 if not isinstance (client , Client ):
266267 raise TypeError ("client must be a dask.distributed.Client" )
267268
@@ -279,6 +280,23 @@ def workers(self, workers):
279280 if not isinstance (workers , list | type (None )):
280281 raise TypeError ("workers must be a list of strings" )
281282
283+ available_workers = [(worker ,) for worker in self .client .nthreads ()]
284+
285+ if workers is None :
286+ workers = available_workers
287+
288+ if not isinstance (workers , list ) or not all (
289+ isinstance (w , tuple ) for w in workers
290+ ):
291+ raise TypeError ("Workers must be a list of tuple[str]." )
292+
293+ invalid_workers = [w for w in workers if w not in available_workers ]
294+ if invalid_workers :
295+ raise ValueError (
296+ f"The following workers are not available: { invalid_workers } . "
297+ f"Available workers are: { available_workers } ."
298+ )
299+
282300 self ._workers = workers
283301
284302 def deriv (self , m , f = None ):
0 commit comments