Skip to content

Commit 8907f0e

Browse files
committed
Add delayed methods for updateing simualtion within Future
1 parent 7da0cb4 commit 8907f0e

1 file changed

Lines changed: 22 additions & 26 deletions

File tree

simpeg/directives/_vector_models.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
2+
from dask.distributed import Future
33
from . import (
44
BaseSaveGeoH5,
55
InversionDirective,
@@ -56,8 +56,11 @@ def update(self):
5656
self.invProb.model = m
5757
self.opt.xc = self.invProb.model
5858

59-
for misfit in self.dmisfit.objfcts:
60-
misfit.simulation.model = m
59+
if isinstance(self.dmisfit.objfcts[0], Future):
60+
self.dmisfit.model = m
61+
else:
62+
for misfit in self.dmisfit.objfcts:
63+
misfit.simulation.model = m
6164

6265
def _reproject(self, m):
6366
"""
@@ -71,6 +74,15 @@ def _reproject(self, m):
7174
return m
7275

7376

77+
def update_map(misfit):
78+
if isinstance(misfit.simulation, MetaSimulation):
79+
misfit.simulation.simulations[0].chiMap = (
80+
SphericalSystem() * misfit.simulation.simulations[0].chiMap
81+
)
82+
else:
83+
misfit.simulation.chiMap = SphericalSystem() * misfit.simulation.chiMap
84+
85+
7486
class VectorInversion(InversionDirective):
7587
"""
7688
Control a vector inversion from Cartesian to spherical coordinates.
@@ -90,19 +102,15 @@ def __init__(
90102
self, simulations: list, regularizations: ComboObjectiveFunction, **kwargs
91103
):
92104
self.reference_angles = (False, False, False)
93-
self.simulations = simulations
105+
self.misfits = simulations
94106
self.regularizations = regularizations
95107

96108
set_kwargs(self, **kwargs)
97109

98110
@property
99111
def target(self):
100112
if getattr(self, "_target", None) is None:
101-
nD = 0
102-
for survey in self.survey:
103-
nD += survey.nD
104-
105-
self._target = nD * self.chifact_target
113+
self._target = np.hstack(self.invProb.dpred).shape[0] * self.chifact_target
106114

107115
return self._target
108116

@@ -116,10 +124,6 @@ def initialize(self):
116124

117125
self.reference_model = reg.reference_model
118126

119-
for dmisfit in self.dmisfit.objfcts:
120-
if getattr(dmisfit.simulation, "coordinate_system", None) is not None:
121-
dmisfit.simulation.coordinate_system = self.mode
122-
123127
def endIter(self):
124128

125129
model = self.invProb.model.copy()
@@ -217,20 +221,12 @@ def endIter(self):
217221
self.opt.upper[indices[nC:]] = np.inf
218222

219223
updates = {}
220-
for simulation in self.simulations:
221-
if isinstance(simulation, MetaSimulation):
222-
223-
if hasattr(self.dmisfit, "client"):
224-
updates[simulation] = (
225-
"chiMap",
226-
SphericalSystem() * simulation.simulations[0].chiMap,
227-
)
228-
else:
229-
simulation.simulations[0].chiMap = (
230-
SphericalSystem() * simulation.simulations[0].chiMap
231-
)
224+
for misfit in self.misfits:
225+
226+
if isinstance(misfit, Future):
227+
updates[misfit] = ("", update_map)
232228
else:
233-
simulation.chiMap = SphericalSystem() * simulation.chiMap
229+
update_map(misfit)
234230

235231
if hasattr(self.dmisfit, "client"):
236232
self.dmisfit.broadcast_updates(updates)

0 commit comments

Comments
 (0)