Skip to content

Commit 69a7888

Browse files
committed
Bring back broadcasting of attributes
1 parent 8907f0e commit 69a7888

1 file changed

Lines changed: 25 additions & 1 deletion

File tree

simpeg/dask/objective_function.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
_validate_multiplier,
55
_check_length_objective_funcs_multipliers,
66
)
7-
7+
from typing import Callable
88
import numpy as np
99

1010
from dask.distributed import Client, Future
@@ -100,6 +100,9 @@ def _setter_broadcast(objfct, key, value):
100100
"""
101101
Broadcast a value to all workers.
102102
"""
103+
if isinstance(value, Callable):
104+
value = value(objfct)
105+
103106
if hasattr(objfct, key):
104107
setattr(objfct, key, value)
105108

@@ -565,3 +568,24 @@ def residuals(self, m, f=None):
565568
residuals += client.gather(future_residuals)
566569

567570
return residuals
571+
572+
def broadcast_updates(self, updates: dict):
573+
"""
574+
Set the attributes of the objective functions and simulations
575+
"""
576+
stores = []
577+
client = self.client
578+
579+
for fun, (key, value) in updates.items():
580+
worker = client.who_has(fun)[fun.key]
581+
stores.append(
582+
client.submit(
583+
_setter_broadcast,
584+
fun,
585+
key,
586+
value,
587+
workers=worker,
588+
)
589+
)
590+
591+
self.client.gather(stores) # blocking call to ensure all models were stored

0 commit comments

Comments
 (0)