Skip to content

Commit 6f518ff

Browse files
committed
Protect agains setting model on simulation outside setter
1 parent a28d832 commit 6f518ff

1 file changed

Lines changed: 37 additions & 35 deletions

File tree

simpeg/directives/_vector_models.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
UpdateIRLS,
1010
UpdateSensitivityWeights,
1111
)
12-
from ..maps import SphericalSystem
12+
from ..maps import SphericalSystem, Wires
1313
from ..meta.simulation import MetaSimulation
1414
from ..objective_function import ComboObjectiveFunction
1515
from ..regularization import CrossGradient
@@ -25,51 +25,50 @@ class ProjectSphericalBounds(InversionDirective):
2525
spherical->cartesian->spherical
2626
"""
2727

28-
def initialize(self):
29-
x = self.invProb.model
30-
# Convert to cartesian than back to avoid over rotation
31-
nC = int(len(x) / 3)
32-
xyz = spherical2cartesian(x.reshape((nC, 3), order="F"))
33-
m = cartesian2spherical(xyz.reshape((nC, 3), order="F"))
34-
self.invProb.model = m
35-
self.opt.xc = m
28+
def __init__(self, mapping: Wires, **kwargs):
29+
if not isinstance(mapping, Wires):
30+
raise TypeError("mapping must be a Wires object")
3631

37-
for misfit in self.dmisfit:
38-
if getattr(misfit, "model_map", None) is not None:
39-
misfit.simulation.model = misfit.model_map @ m
40-
else:
41-
misfit.simulation.model = m
32+
if len(mapping.maps) != 3:
33+
raise ValueError("mapping must have 3 maps, one per vector component.")
34+
35+
self.indices = mapping.deriv(None).indices
36+
super().__init__(**kwargs)
37+
38+
def initialize(self):
39+
self.update()
4240

4341
def endIter(self):
44-
for misfit in self.dmisfit.objfcts:
45-
if (
46-
hasattr(misfit.simulation, "model_type")
47-
and misfit.simulation.model_type == "vector"
48-
):
49-
mapping = misfit.model_map.deriv(np.zeros(misfit.model_map.shape[1]))
50-
indices = (
51-
mapping.indices
52-
) # np.array(np.sum(mapping, axis=0)).flatten() > 0
53-
nC = int(len(indices) / 3)
54-
vec = self.invProb.model[indices]
55-
# Convert to cartesian than back to avoid over rotation
56-
xyz = spherical2cartesian(vec.reshape((nC, 3), order="F"))
57-
vec = cartesian2spherical(xyz.reshape((nC, 3), order="F"))
58-
self.invProb.model[indices] = vec
42+
self.update()
5943

44+
def update(self):
45+
"""
46+
Update the model and the simulation
47+
"""
48+
x = self.invProb.model
49+
m = self._reproject(x)
6050
phi_m_last = []
6151
for reg in self.reg.objfcts:
6252
reg.model = self.invProb.model
6353
phi_m_last += [reg(self.invProb.model)]
6454

6555
self.invProb.phi_m_last = phi_m_last
56+
self.invProb.model = m
6657
self.opt.xc = self.invProb.model
6758

6859
for misfit in self.dmisfit.objfcts:
69-
if getattr(misfit, "model_map", None) is not None:
70-
misfit.simulation.model = misfit.model_map @ self.invProb.model
71-
else:
72-
misfit.simulation.model = self.invProb.model
60+
misfit.simulation.model = m
61+
62+
def _reproject(self, m):
63+
"""
64+
Round trip conversion to reproject the model.
65+
"""
66+
vec = m[self.indices]
67+
xyz = spherical2cartesian(vec.reshape((-1, 3), order="F"))
68+
vec = cartesian2spherical(xyz.reshape((-1, 3), order="F"))
69+
70+
m[self.indices] = vec
71+
return m
7372

7473

7574
class VectorInversion(InversionDirective):
@@ -128,11 +127,13 @@ def endIter(self):
128127
print("Switching MVI to spherical coordinates")
129128
self.mode = "spherical"
130129
self.cartesian_model = self.invProb.model
131-
model = self.invProb.model
130+
model = self.invProb.model.copy()
132131
vec_model = []
133132
vec_ref = []
134133
indices = []
134+
mappings = []
135135
for reg in self.regularizations.objfcts:
136+
mappings.append(reg.mapping)
136137
vec_model.append(reg.mapping * model)
137138
vec_ref.append(reg.mapping * reg.reference_model)
138139
mapping = reg.mapping.deriv(np.zeros(reg.mapping.shape[1]))
@@ -220,8 +221,9 @@ def endIter(self):
220221
amplitude=self.regularizations.objfcts[0].mapping,
221222
angles=self.regularizations.objfcts[1:],
222223
)
224+
projections = [(comp, mapping) for comp, mapping in zip("xyz", mappings)]
223225
directiveList = [
224-
ProjectSphericalBounds(),
226+
ProjectSphericalBounds(Wires(*projections)),
225227
spherical_units,
226228
] + self.inversion.directiveList.dList
227229
self.inversion.directiveList = directiveList

0 commit comments

Comments
 (0)