99 UpdateIRLS ,
1010 UpdateSensitivityWeights ,
1111)
12- from ..maps import SphericalSystem
12+ from ..maps import SphericalSystem , Wires
1313from ..meta .simulation import MetaSimulation
1414from ..objective_function import ComboObjectiveFunction
1515from ..regularization import CrossGradient
@@ -25,51 +25,50 @@ class ProjectSphericalBounds(InversionDirective):
2525 spherical->cartesian->spherical
2626 """
2727
28- def initialize (self ):
29- x = self .invProb .model
30- # Convert to cartesian than back to avoid over rotation
31- nC = int (len (x ) / 3 )
32- xyz = spherical2cartesian (x .reshape ((nC , 3 ), order = "F" ))
33- m = cartesian2spherical (xyz .reshape ((nC , 3 ), order = "F" ))
34- self .invProb .model = m
35- self .opt .xc = m
28+ def __init__ (self , mapping : Wires , ** kwargs ):
29+ if not isinstance (mapping , Wires ):
30+ raise TypeError ("mapping must be a Wires object" )
3631
37- for misfit in self .dmisfit :
38- if getattr (misfit , "model_map" , None ) is not None :
39- misfit .simulation .model = misfit .model_map @ m
40- else :
41- misfit .simulation .model = m
32+ if len (mapping .maps ) != 3 :
33+ raise ValueError ("mapping must have 3 maps, one per vector component." )
34+
35+ self .indices = mapping .deriv (None ).indices
36+ super ().__init__ (** kwargs )
37+
38+ def initialize (self ):
39+ self .update ()
4240
4341 def endIter (self ):
44- for misfit in self .dmisfit .objfcts :
45- if (
46- hasattr (misfit .simulation , "model_type" )
47- and misfit .simulation .model_type == "vector"
48- ):
49- mapping = misfit .model_map .deriv (np .zeros (misfit .model_map .shape [1 ]))
50- indices = (
51- mapping .indices
52- ) # np.array(np.sum(mapping, axis=0)).flatten() > 0
53- nC = int (len (indices ) / 3 )
54- vec = self .invProb .model [indices ]
55- # Convert to cartesian than back to avoid over rotation
56- xyz = spherical2cartesian (vec .reshape ((nC , 3 ), order = "F" ))
57- vec = cartesian2spherical (xyz .reshape ((nC , 3 ), order = "F" ))
58- self .invProb .model [indices ] = vec
42+ self .update ()
5943
44+ def update (self ):
45+ """
46+ Update the model and the simulation
47+ """
48+ x = self .invProb .model
49+ m = self ._reproject (x )
6050 phi_m_last = []
6151 for reg in self .reg .objfcts :
6252 reg .model = self .invProb .model
6353 phi_m_last += [reg (self .invProb .model )]
6454
6555 self .invProb .phi_m_last = phi_m_last
56+ self .invProb .model = m
6657 self .opt .xc = self .invProb .model
6758
6859 for misfit in self .dmisfit .objfcts :
69- if getattr (misfit , "model_map" , None ) is not None :
70- misfit .simulation .model = misfit .model_map @ self .invProb .model
71- else :
72- misfit .simulation .model = self .invProb .model
60+ misfit .simulation .model = m
61+
62+ def _reproject (self , m ):
63+ """
64+ Round trip conversion to reproject the model.
65+ """
66+ vec = m [self .indices ]
67+ xyz = spherical2cartesian (vec .reshape ((- 1 , 3 ), order = "F" ))
68+ vec = cartesian2spherical (xyz .reshape ((- 1 , 3 ), order = "F" ))
69+
70+ m [self .indices ] = vec
71+ return m
7372
7473
7574class VectorInversion (InversionDirective ):
@@ -128,11 +127,13 @@ def endIter(self):
128127 print ("Switching MVI to spherical coordinates" )
129128 self .mode = "spherical"
130129 self .cartesian_model = self .invProb .model
131- model = self .invProb .model
130+ model = self .invProb .model . copy ()
132131 vec_model = []
133132 vec_ref = []
134133 indices = []
134+ mappings = []
135135 for reg in self .regularizations .objfcts :
136+ mappings .append (reg .mapping )
136137 vec_model .append (reg .mapping * model )
137138 vec_ref .append (reg .mapping * reg .reference_model )
138139 mapping = reg .mapping .deriv (np .zeros (reg .mapping .shape [1 ]))
@@ -220,8 +221,9 @@ def endIter(self):
220221 amplitude = self .regularizations .objfcts [0 ].mapping ,
221222 angles = self .regularizations .objfcts [1 :],
222223 )
224+ projections = [(comp , mapping ) for comp , mapping in zip ("xyz" , mappings )]
223225 directiveList = [
224- ProjectSphericalBounds (),
226+ ProjectSphericalBounds (Wires ( * projections ) ),
225227 spherical_units ,
226228 ] + self .inversion .directiveList .dList
227229 self .inversion .directiveList = directiveList
0 commit comments