Skip to content

Commit 998d6f5

Browse files
authored
Merge pull request #117 from MiraGeoscience/GEOPY-2461
GEOPY-2461: Failure of directives when using Futures for simulation
2 parents 7da0cb4 + 69a7888 commit 998d6f5

2 files changed

Lines changed: 47 additions & 27 deletions

File tree

simpeg/dask/objective_function.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
_validate_multiplier,
55
_check_length_objective_funcs_multipliers,
66
)
7-
7+
from typing import Callable
88
import numpy as np
99

1010
from dask.distributed import Client, Future
@@ -100,6 +100,9 @@ def _setter_broadcast(objfct, key, value):
100100
"""
101101
Broadcast a value to all workers.
102102
"""
103+
if isinstance(value, Callable):
104+
value = value(objfct)
105+
103106
if hasattr(objfct, key):
104107
setattr(objfct, key, value)
105108

@@ -565,3 +568,24 @@ def residuals(self, m, f=None):
565568
residuals += client.gather(future_residuals)
566569

567570
return residuals
571+
572+
def broadcast_updates(self, updates: dict):
573+
"""
574+
Set the attributes of the objective functions and simulations
575+
"""
576+
stores = []
577+
client = self.client
578+
579+
for fun, (key, value) in updates.items():
580+
worker = client.who_has(fun)[fun.key]
581+
stores.append(
582+
client.submit(
583+
_setter_broadcast,
584+
fun,
585+
key,
586+
value,
587+
workers=worker,
588+
)
589+
)
590+
591+
self.client.gather(stores) # blocking call to ensure all models were stored

simpeg/directives/_vector_models.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
2+
from dask.distributed import Future
33
from . import (
44
BaseSaveGeoH5,
55
InversionDirective,
@@ -56,8 +56,11 @@ def update(self):
5656
self.invProb.model = m
5757
self.opt.xc = self.invProb.model
5858

59-
for misfit in self.dmisfit.objfcts:
60-
misfit.simulation.model = m
59+
if isinstance(self.dmisfit.objfcts[0], Future):
60+
self.dmisfit.model = m
61+
else:
62+
for misfit in self.dmisfit.objfcts:
63+
misfit.simulation.model = m
6164

6265
def _reproject(self, m):
6366
"""
@@ -71,6 +74,15 @@ def _reproject(self, m):
7174
return m
7275

7376

77+
def update_map(misfit):
78+
if isinstance(misfit.simulation, MetaSimulation):
79+
misfit.simulation.simulations[0].chiMap = (
80+
SphericalSystem() * misfit.simulation.simulations[0].chiMap
81+
)
82+
else:
83+
misfit.simulation.chiMap = SphericalSystem() * misfit.simulation.chiMap
84+
85+
7486
class VectorInversion(InversionDirective):
7587
"""
7688
Control a vector inversion from Cartesian to spherical coordinates.
@@ -90,19 +102,15 @@ def __init__(
90102
self, simulations: list, regularizations: ComboObjectiveFunction, **kwargs
91103
):
92104
self.reference_angles = (False, False, False)
93-
self.simulations = simulations
105+
self.misfits = simulations
94106
self.regularizations = regularizations
95107

96108
set_kwargs(self, **kwargs)
97109

98110
@property
99111
def target(self):
100112
if getattr(self, "_target", None) is None:
101-
nD = 0
102-
for survey in self.survey:
103-
nD += survey.nD
104-
105-
self._target = nD * self.chifact_target
113+
self._target = np.hstack(self.invProb.dpred).shape[0] * self.chifact_target
106114

107115
return self._target
108116

@@ -116,10 +124,6 @@ def initialize(self):
116124

117125
self.reference_model = reg.reference_model
118126

119-
for dmisfit in self.dmisfit.objfcts:
120-
if getattr(dmisfit.simulation, "coordinate_system", None) is not None:
121-
dmisfit.simulation.coordinate_system = self.mode
122-
123127
def endIter(self):
124128

125129
model = self.invProb.model.copy()
@@ -217,20 +221,12 @@ def endIter(self):
217221
self.opt.upper[indices[nC:]] = np.inf
218222

219223
updates = {}
220-
for simulation in self.simulations:
221-
if isinstance(simulation, MetaSimulation):
222-
223-
if hasattr(self.dmisfit, "client"):
224-
updates[simulation] = (
225-
"chiMap",
226-
SphericalSystem() * simulation.simulations[0].chiMap,
227-
)
228-
else:
229-
simulation.simulations[0].chiMap = (
230-
SphericalSystem() * simulation.simulations[0].chiMap
231-
)
224+
for misfit in self.misfits:
225+
226+
if isinstance(misfit, Future):
227+
updates[misfit] = ("", update_map)
232228
else:
233-
simulation.chiMap = SphericalSystem() * simulation.chiMap
229+
update_map(misfit)
234230

235231
if hasattr(self.dmisfit, "client"):
236232
self.dmisfit.broadcast_updates(updates)

0 commit comments

Comments
 (0)