@@ -15,6 +15,10 @@ def _calc_dpred(objfct, _):
1515 return objfct .simulation .dpred (m = objfct .simulation .model )
1616
1717
18+ def _calc_objective (objfct , multiplier , model ):
19+ return multiplier * objfct (model )
20+
21+
1822def _calc_residual (objfct , _ ):
1923 return objfct .W * (
2024 objfct .data .dobs - objfct .simulation .dpred (m = objfct .simulation .model )
@@ -33,6 +37,18 @@ def _store_model(objfct, model):
3337 objfct .simulation .model = model
3438
3539
40+ def _setter_broadcast (objfct , key , value ):
41+ """
42+ Broadcast a value to all workers.
43+ """
44+ if hasattr (objfct , key ):
45+ setattr (objfct , key , value )
46+
47+ for sim in objfct .simulation .simulations :
48+ if hasattr (sim , key ):
49+ setattr (sim , key , value )
50+
51+
3652def _get_jtj_diag (objfct , _ ):
3753 jtj = objfct .simulation .getJtJdiag (objfct .simulation .model , objfct .W )
3854 return jtj .flatten ()
@@ -56,14 +72,15 @@ def _validate_type_or_future_of_type(
5672 property_name , objects , obj_type , ensure_unique = True
5773 )
5874 workload = [[]]
75+ lookup = {}
5976 count = 0
6077 for obj in objects :
6178 if count == len (workers ):
6279 count = 0
6380 workload .append ([])
6481 obj .simulation .simulations [0 ].worker = workers [count ]
6582 future = client .scatter ([obj ], workers = workers [count ])[0 ]
66-
83+ lookup [ obj ] = ( future , workers [ count ])
6784 if hasattr (obj , "name" ):
6885 future .name = obj .name
6986
@@ -84,7 +101,7 @@ def _validate_type_or_future_of_type(
84101 raise TypeError (f"{ property_name } futures must be an instance of { obj_type } " )
85102
86103 if return_workers :
87- return workload , workers
104+ return workload , workers , lookup
88105 else :
89106 return workload
90107
@@ -374,7 +391,7 @@ def objfcts(self):
374391 def objfcts (self , objfcts ):
375392 client = self .client
376393
377- futures , workers = _validate_type_or_future_of_type (
394+ futures , workers , lookup = _validate_type_or_future_of_type (
378395 "objfcts" ,
379396 objfcts ,
380397 L2DataMisfit ,
@@ -387,6 +404,11 @@ def objfcts(self, objfcts):
387404 self ._futures = futures
388405 self ._workers = workers
389406
407+ self ._lookup = {
408+ misfit .simulation : (future , worker )
409+ for misfit , (future , worker ) in lookup .items ()
410+ }
411+
390412 def residuals (self , m , f = None ):
391413 """
392414 Compute the residual for the data misfit.
@@ -411,3 +433,26 @@ def residuals(self, m, f=None):
411433 residuals += client .gather (future_residuals )
412434
413435 return residuals
436+
437+ def broadcast_updates (self , updates : dict ):
438+ """
439+ Set the attributes of the objective functions and simulations
440+ """
441+ stores = []
442+ client = self .client
443+ for fun , (key , value ) in updates .items ():
444+ if fun not in self ._lookup :
445+ continue
446+
447+ future , worker = self ._lookup [fun ]
448+
449+ stores .append (
450+ client .submit (
451+ _setter_broadcast ,
452+ future ,
453+ key ,
454+ value ,
455+ workers = worker ,
456+ )
457+ )
458+ self .client .gather (stores ) # blocking call to ensure all models were stored
0 commit comments