Skip to content

Commit 27934f8

Browse files
committed
Clean up of inversion model module. Allow to not trim the dip azimut models
1 parent e25f99d commit 27934f8

2 files changed

Lines changed: 58 additions & 67 deletions

File tree

simpeg_drivers/components/models.py

Lines changed: 57 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,25 @@
2929
from simpeg_drivers.driver import InversionDriver
3030

3131

32+
MODEL_TYPES = [
33+
"starting",
34+
"reference",
35+
"lower_bound",
36+
"upper_bound",
37+
"conductivity",
38+
"alpha_s",
39+
"length_scale_x",
40+
"length_scale_y",
41+
"length_scale_z",
42+
"gradient_dip",
43+
"gradient_direction",
44+
"s_norm",
45+
"x_norm",
46+
"y_norm",
47+
"z_norm",
48+
]
49+
50+
3251
class InversionModelCollection:
3352
"""
3453
Collection of inversion models.
@@ -39,50 +58,41 @@ class InversionModelCollection:
3958
4059
"""
4160

42-
model_types = [
43-
"starting",
44-
"reference",
45-
"lower_bound",
46-
"upper_bound",
47-
"conductivity",
48-
"alpha_s",
49-
"length_scale_x",
50-
"length_scale_y",
51-
"length_scale_z",
52-
"s_norm",
53-
"x_norm",
54-
"y_norm",
55-
"z_norm",
56-
]
57-
5861
def __init__(self, driver: InversionDriver):
5962
"""
6063
:param driver: Parental InversionDriver class.
6164
"""
6265
self._active_cells: np.ndarray | None = None
6366
self._driver = driver
6467
self.is_sigma = self.driver.params.physical_property == "conductivity"
65-
self.is_vector = (
68+
is_vector = (
6669
True if self.driver.params.inversion_type == "magnetic vector" else False
6770
)
68-
self.n_blocks = (
69-
3 if self.driver.params.inversion_type == "magnetic vector" else 1
71+
self._starting = InversionModel(driver, "starting", is_vector=is_vector)
72+
self._reference = InversionModel(driver, "reference", is_vector=is_vector)
73+
self._lower_bound = InversionModel(driver, "lower_bound", is_vector=is_vector)
74+
self._upper_bound = InversionModel(driver, "upper_bound", is_vector=is_vector)
75+
self._conductivity = InversionModel(driver, "conductivity", is_vector=is_vector)
76+
self._alpha_s = InversionModel(driver, "alpha_s", is_vector=is_vector)
77+
self._length_scale_x = InversionModel(
78+
driver, "length_scale_x", is_vector=is_vector
79+
)
80+
self._length_scale_y = InversionModel(
81+
driver, "length_scale_y", is_vector=is_vector
82+
)
83+
self._length_scale_z = InversionModel(
84+
driver, "length_scale_z", is_vector=is_vector
85+
)
86+
self._gradient_dip = InversionModel(
87+
driver, "gradient_dip", trim_active_cells=False
7088
)
71-
self._starting = InversionModel(driver, "starting")
72-
self._reference = InversionModel(driver, "reference")
73-
self._lower_bound = InversionModel(driver, "lower_bound")
74-
self._upper_bound = InversionModel(driver, "upper_bound")
75-
self._conductivity = InversionModel(driver, "conductivity")
76-
self._alpha_s = InversionModel(driver, "alpha_s")
77-
self._length_scale_x = InversionModel(driver, "length_scale_x")
78-
self._length_scale_y = InversionModel(driver, "length_scale_y")
79-
self._length_scale_z = InversionModel(driver, "length_scale_z")
80-
self._gradient_dip = InversionModel(driver, "gradient_dip")
81-
self._gradient_direction = InversionModel(driver, "gradient_direction")
82-
self._s_norm = InversionModel(driver, "s_norm")
83-
self._x_norm = InversionModel(driver, "x_norm")
84-
self._y_norm = InversionModel(driver, "y_norm")
85-
self._z_norm = InversionModel(driver, "z_norm")
89+
self._gradient_direction = InversionModel(
90+
driver, "gradient_direction", trim_active_cells=False
91+
)
92+
self._s_norm = InversionModel(driver, "s_norm", is_vector=is_vector)
93+
self._x_norm = InversionModel(driver, "x_norm", is_vector=is_vector)
94+
self._y_norm = InversionModel(driver, "y_norm", is_vector=is_vector)
95+
self._z_norm = InversionModel(driver, "z_norm", is_vector=is_vector)
8696

8797
@property
8898
def n_active(self) -> int:
@@ -307,7 +317,7 @@ def z_norm(self) -> np.ndarray | None:
307317
def _model_method_wrapper(self, method, name=None, **kwargs):
308318
"""wraps individual model's specific method and applies in loop over model types."""
309319
returned_items = {}
310-
for mtype in self.model_types:
320+
for mtype in MODEL_TYPES:
311321
model = getattr(self, f"_{mtype}")
312322
if model.model is not None:
313323
f = getattr(model, method)
@@ -364,43 +374,24 @@ class InversionModel:
364374
remove_air: Use active cells vector to remove air cells from model.
365375
"""
366376

367-
model_types = [
368-
"starting",
369-
"reference",
370-
"lower_bound",
371-
"upper_bound",
372-
"conductivity",
373-
"alpha_s",
374-
"length_scale_x",
375-
"length_scale_y",
376-
"length_scale_z",
377-
"gradient_dip",
378-
"gradient_direction",
379-
"s_norm",
380-
"x_norm",
381-
"y_norm",
382-
"z_norm",
383-
]
384-
385377
def __init__(
386378
self,
387379
driver: InversionDriver,
388380
model_type: str,
381+
is_vector: bool = False,
382+
trim_active_cells: bool = True,
389383
):
390384
"""
391385
:param driver: InversionDriver object.
392-
:param model_type: Type of inversion model, can be any of "starting", "reference",
393-
"lower_bound", "upper_bound".
386+
:param model_type: Type of inversion model, can be any of MODEL_TYPES.
387+
:param is_vector: If True, model is a vector.
388+
:param trim_active_cells: If True, remove air cells from model.
394389
"""
395390
self.driver = driver
396391
self.model_type = model_type
397392
self.model: np.ndarray | None = None
398-
self.is_vector = (
399-
True if self.driver.params.inversion_type == "magnetic vector" else False
400-
)
401-
self.n_blocks = (
402-
3 if self.driver.params.inversion_type == "magnetic vector" else 1
403-
)
393+
self.is_vector = is_vector
394+
self.trim_active_cells = trim_active_cells
404395
self._initialize()
405396

406397
def _initialize(self):
@@ -452,7 +443,7 @@ def _initialize(self):
452443
and self.is_vector
453444
and model.shape[0] == self.driver.inversion_mesh.n_cells
454445
):
455-
model = np.tile(model, self.n_blocks)
446+
model = np.tile(model, 3 if self.is_vector else 1)
456447

457448
if model is not None:
458449
self.model = mkvc(model)
@@ -461,8 +452,8 @@ def _initialize(self):
461452
def remove_air(self, active_cells):
462453
"""Use active cells vector to remove air cells from model"""
463454

464-
if self.model is not None:
465-
self.model = self.model[np.tile(active_cells, self.n_blocks)]
455+
if self.model is not None and self.trim_active_cells:
456+
self.model = self.model[np.tile(active_cells, 3 if self.is_vector else 1)]
466457

467458
def permute_2_octree(self) -> np.ndarray | None:
468459
"""
@@ -604,7 +595,7 @@ def model_type(self):
604595

605596
@model_type.setter
606597
def model_type(self, v):
607-
if v not in self.model_types:
608-
msg = f"Invalid model_type: {v}. Must be one of {(*self.model_types,)}."
598+
if v not in MODEL_TYPES:
599+
msg = f"Invalid model_type: {v}. Must be one of {(*MODEL_TYPES,)}."
609600
raise ValueError(msg)
610601
self._model_type = v

simpeg_drivers/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def get_regularization(self):
459459
reference_model=self.models.reference,
460460
)
461461

462-
if is_rotated and not neighbors:
462+
if is_rotated and neighbors is None:
463463
backward_mesh = RegularizationMesh(
464464
self.inversion_mesh.mesh, active_cells=self.models.active_cells
465465
)

0 commit comments

Comments
 (0)