Skip to content

Commit ff56ac1

Browse files
authored
Merge pull request #179 from MiraGeoscience/GEOPY-2075_c
GEOPY-2075: c
2 parents f63c829 + fdfcc86 commit ff56ac1

3 files changed

Lines changed: 174 additions & 130 deletions

File tree

simpeg_drivers/components/models.py

Lines changed: 57 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,25 @@
2929
from simpeg_drivers.driver import InversionDriver
3030

3131

32+
MODEL_TYPES = [
33+
"starting",
34+
"reference",
35+
"lower_bound",
36+
"upper_bound",
37+
"conductivity",
38+
"alpha_s",
39+
"length_scale_x",
40+
"length_scale_y",
41+
"length_scale_z",
42+
"gradient_dip",
43+
"gradient_direction",
44+
"s_norm",
45+
"x_norm",
46+
"y_norm",
47+
"z_norm",
48+
]
49+
50+
3251
class InversionModelCollection:
3352
"""
3453
Collection of inversion models.
@@ -39,50 +58,41 @@ class InversionModelCollection:
3958
4059
"""
4160

42-
model_types = [
43-
"starting",
44-
"reference",
45-
"lower_bound",
46-
"upper_bound",
47-
"conductivity",
48-
"alpha_s",
49-
"length_scale_x",
50-
"length_scale_y",
51-
"length_scale_z",
52-
"s_norm",
53-
"x_norm",
54-
"y_norm",
55-
"z_norm",
56-
]
57-
5861
def __init__(self, driver: InversionDriver):
5962
"""
6063
:param driver: Parental InversionDriver class.
6164
"""
6265
self._active_cells: np.ndarray | None = None
6366
self._driver = driver
6467
self.is_sigma = self.driver.params.physical_property == "conductivity"
65-
self.is_vector = (
68+
is_vector = (
6669
True if self.driver.params.inversion_type == "magnetic vector" else False
6770
)
68-
self.n_blocks = (
69-
3 if self.driver.params.inversion_type == "magnetic vector" else 1
71+
self._starting = InversionModel(driver, "starting", is_vector=is_vector)
72+
self._reference = InversionModel(driver, "reference", is_vector=is_vector)
73+
self._lower_bound = InversionModel(driver, "lower_bound", is_vector=is_vector)
74+
self._upper_bound = InversionModel(driver, "upper_bound", is_vector=is_vector)
75+
self._conductivity = InversionModel(driver, "conductivity", is_vector=is_vector)
76+
self._alpha_s = InversionModel(driver, "alpha_s", is_vector=is_vector)
77+
self._length_scale_x = InversionModel(
78+
driver, "length_scale_x", is_vector=is_vector
79+
)
80+
self._length_scale_y = InversionModel(
81+
driver, "length_scale_y", is_vector=is_vector
82+
)
83+
self._length_scale_z = InversionModel(
84+
driver, "length_scale_z", is_vector=is_vector
85+
)
86+
self._gradient_dip = InversionModel(
87+
driver, "gradient_dip", trim_active_cells=False
7088
)
71-
self._starting = InversionModel(driver, "starting")
72-
self._reference = InversionModel(driver, "reference")
73-
self._lower_bound = InversionModel(driver, "lower_bound")
74-
self._upper_bound = InversionModel(driver, "upper_bound")
75-
self._conductivity = InversionModel(driver, "conductivity")
76-
self._alpha_s = InversionModel(driver, "alpha_s")
77-
self._length_scale_x = InversionModel(driver, "length_scale_x")
78-
self._length_scale_y = InversionModel(driver, "length_scale_y")
79-
self._length_scale_z = InversionModel(driver, "length_scale_z")
80-
self._gradient_dip = InversionModel(driver, "gradient_dip")
81-
self._gradient_direction = InversionModel(driver, "gradient_direction")
82-
self._s_norm = InversionModel(driver, "s_norm")
83-
self._x_norm = InversionModel(driver, "x_norm")
84-
self._y_norm = InversionModel(driver, "y_norm")
85-
self._z_norm = InversionModel(driver, "z_norm")
89+
self._gradient_direction = InversionModel(
90+
driver, "gradient_direction", trim_active_cells=False
91+
)
92+
self._s_norm = InversionModel(driver, "s_norm", is_vector=is_vector)
93+
self._x_norm = InversionModel(driver, "x_norm", is_vector=is_vector)
94+
self._y_norm = InversionModel(driver, "y_norm", is_vector=is_vector)
95+
self._z_norm = InversionModel(driver, "z_norm", is_vector=is_vector)
8696

8797
@property
8898
def n_active(self) -> int:
@@ -307,7 +317,7 @@ def z_norm(self) -> np.ndarray | None:
307317
def _model_method_wrapper(self, method, name=None, **kwargs):
308318
"""wraps individual model's specific method and applies in loop over model types."""
309319
returned_items = {}
310-
for mtype in self.model_types:
320+
for mtype in MODEL_TYPES:
311321
model = getattr(self, f"_{mtype}")
312322
if model.model is not None:
313323
f = getattr(model, method)
@@ -364,43 +374,24 @@ class InversionModel:
364374
remove_air: Use active cells vector to remove air cells from model.
365375
"""
366376

367-
model_types = [
368-
"starting",
369-
"reference",
370-
"lower_bound",
371-
"upper_bound",
372-
"conductivity",
373-
"alpha_s",
374-
"length_scale_x",
375-
"length_scale_y",
376-
"length_scale_z",
377-
"gradient_dip",
378-
"gradient_direction",
379-
"s_norm",
380-
"x_norm",
381-
"y_norm",
382-
"z_norm",
383-
]
384-
385377
def __init__(
386378
self,
387379
driver: InversionDriver,
388380
model_type: str,
381+
is_vector: bool = False,
382+
trim_active_cells: bool = True,
389383
):
390384
"""
391385
:param driver: InversionDriver object.
392-
:param model_type: Type of inversion model, can be any of "starting", "reference",
393-
"lower_bound", "upper_bound".
386+
:param model_type: Type of inversion model, can be any of MODEL_TYPES.
387+
:param is_vector: If True, model is a vector.
388+
:param trim_active_cells: If True, remove air cells from model.
394389
"""
395390
self.driver = driver
396391
self.model_type = model_type
397392
self.model: np.ndarray | None = None
398-
self.is_vector = (
399-
True if self.driver.params.inversion_type == "magnetic vector" else False
400-
)
401-
self.n_blocks = (
402-
3 if self.driver.params.inversion_type == "magnetic vector" else 1
403-
)
393+
self.is_vector = is_vector
394+
self.trim_active_cells = trim_active_cells
404395
self._initialize()
405396

406397
def _initialize(self):
@@ -452,7 +443,7 @@ def _initialize(self):
452443
and self.is_vector
453444
and model.shape[0] == self.driver.inversion_mesh.n_cells
454445
):
455-
model = np.tile(model, self.n_blocks)
446+
model = np.tile(model, 3 if self.is_vector else 1)
456447

457448
if model is not None:
458449
self.model = mkvc(model)
@@ -461,8 +452,8 @@ def _initialize(self):
461452
def remove_air(self, active_cells):
462453
"""Use active cells vector to remove air cells from model"""
463454

464-
if self.model is not None:
465-
self.model = self.model[np.tile(active_cells, self.n_blocks)]
455+
if self.model is not None and self.trim_active_cells:
456+
self.model = self.model[np.tile(active_cells, 3 if self.is_vector else 1)]
466457

467458
def permute_2_octree(self) -> np.ndarray | None:
468459
"""
@@ -604,7 +595,7 @@ def model_type(self):
604595

605596
@model_type.setter
606597
def model_type(self, v):
607-
if v not in self.model_types:
608-
msg = f"Invalid model_type: {v}. Must be one of {(*self.model_types,)}."
598+
if v not in MODEL_TYPES:
599+
msg = f"Invalid model_type: {v}. Must be one of {(*MODEL_TYPES,)}."
609600
raise ValueError(msg)
610601
self._model_type = v

simpeg_drivers/driver.py

Lines changed: 79 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import multiprocessing
17-
17+
from copy import deepcopy
1818
import sys
1919
from datetime import datetime, timedelta
2020
import logging
@@ -42,8 +42,13 @@
4242
objective_function,
4343
optimization,
4444
)
45-
from simpeg.utils import sdiag
46-
from simpeg.regularization import BaseRegularization, Sparse
45+
46+
from simpeg.regularization import (
47+
BaseRegularization,
48+
RegularizationMesh,
49+
Sparse,
50+
SparseSmoothness,
51+
)
4752

4853
from simpeg_drivers import DRIVER_MAP
4954
from simpeg_drivers.components import (
@@ -60,7 +65,7 @@
6065
)
6166
from simpeg_drivers.joint.params import BaseJointOptions
6267
from simpeg_drivers.utils.utils import tile_locations
63-
from simpeg_drivers.utils.regularization import cell_neighbors, rotated_gradient
68+
from simpeg_drivers.utils.regularization import cell_neighbors, set_rotated_operators
6469

6570
mlogger = logging.getLogger("distributed")
6671
mlogger.setLevel(logging.WARNING)
@@ -442,91 +447,101 @@ def get_regularization(self):
442447
return BaseRegularization(mesh=self.inversion_mesh.mesh)
443448

444449
reg_funcs = []
445-
total_reg_funcs = []
446450
is_rotated = self.params.gradient_rotation is not None
451+
neighbors = None
452+
backward_mesh = None
453+
forward_mesh = None
447454
for mapping in self.mapping:
448-
reg_funcs.append(
449-
Sparse(
450-
self.inversion_mesh.mesh,
451-
active_cells=self.models.active_cells,
452-
mapping=mapping,
453-
reference_model=self.models.reference,
454-
)
455+
reg_func = Sparse(
456+
forward_mesh or self.inversion_mesh.mesh,
457+
active_cells=self.models.active_cells if forward_mesh is None else None,
458+
mapping=mapping,
459+
reference_model=self.models.reference,
455460
)
456461

457-
if is_rotated:
458-
reg_funcs.append(
459-
Sparse(
460-
self.inversion_mesh.mesh,
461-
active_cells=self.models.active_cells,
462-
mapping=mapping,
463-
reference_model=self.models.reference,
464-
)
462+
if is_rotated and neighbors is None:
463+
backward_mesh = RegularizationMesh(
464+
self.inversion_mesh.mesh, active_cells=self.models.active_cells
465465
)
466-
neighbors = cell_neighbors(reg_funcs[0].regularization_mesh.mesh)
466+
neighbors = cell_neighbors(reg_func.regularization_mesh.mesh)
467467

468468
# Adjustment for 2D versus 3D problems
469-
comps = "sxz" if "2d" in self.params.inversion_type else "sxyz"
470-
avg_comps = "sxy" if "2d" in self.params.inversion_type else "sxyz"
471-
weights = ["alpha_s"] + [f"length_scale_{k}" for k in comps[1:]]
472-
473-
for comp, avg_comp, objfct, weight in zip(
474-
comps, avg_comps, reg_funcs[0].objfcts, weights
469+
components = "sxz" if "2d" in self.params.inversion_type else "sxyz"
470+
weight_names = ["alpha_s"] + [f"length_scale_{k}" for k in components[1:]]
471+
functions = []
472+
for comp, weight_name, fun in zip(
473+
components, weight_names, reg_func.objfcts
475474
):
476-
if getattr(self.models, weight) is None:
477-
setattr(reg_funcs[0], weight, 0.0)
475+
if getattr(self.models, weight_name) is None:
476+
setattr(reg_funcs, weight_name, 0.0)
477+
functions.append(fun)
478478
continue
479479

480-
weight = mapping * getattr(self.models, weight)
480+
weight = mapping * getattr(self.models, weight_name)
481481
norm = mapping * getattr(self.models, f"{comp}_norm")
482-
if comp in "xyz":
482+
483+
if isinstance(fun, SparseSmoothness):
483484
if is_rotated:
484-
for reg, forward in zip(reg_funcs, [True, False]):
485-
grad_op = rotated_gradient(
486-
mesh=reg.regularization_mesh.mesh,
487-
neighbors=neighbors,
488-
axis=comp,
489-
dip=self.models.gradient_dip,
490-
direction=self.models.gradient_direction,
491-
forward=forward,
492-
)
493-
setattr(
494-
reg.regularization_mesh,
495-
f"_cell_gradient_{comp}",
496-
reg.regularization_mesh.Pac.T
497-
@ (grad_op @ reg.regularization_mesh.Pac),
498-
)
499-
setattr(
500-
reg.regularization_mesh,
501-
f"_aveCC2F{avg_comp}",
502-
sdiag(np.ones(reg.regularization_mesh.n_cells)),
485+
if forward_mesh is None:
486+
fun = set_rotated_operators(
487+
fun,
488+
neighbors,
489+
comp,
490+
self.models.gradient_dip,
491+
self.models.gradient_direction,
503492
)
493+
504494
else:
505495
weight = (
506496
getattr(
507-
reg_funcs[0].regularization_mesh, f"aveCC2F{avg_comp}"
497+
reg_func.regularization_mesh,
498+
f"aveCC2F{fun.orientation}",
508499
)
509500
* weight
510501
)
511502
norm = (
512503
getattr(
513-
reg_funcs[0].regularization_mesh, f"aveCC2F{avg_comp}"
504+
reg_func.regularization_mesh,
505+
f"aveCC2F{fun.orientation}",
514506
)
515507
* norm
516508
)
517509

518-
objfct.set_weights(**{comp: weight})
519-
objfct.norm = norm
520-
521-
if getattr(self.params, "gradient_type") is not None:
522-
for reg in reg_funcs:
523-
setattr(
524-
reg,
525-
"gradient_type",
526-
getattr(self.params, "gradient_type"),
527-
)
528-
529-
total_reg_funcs += reg_funcs
510+
fun.set_weights(**{comp: weight})
511+
fun.norm = norm
512+
functions.append(fun)
513+
514+
if isinstance(fun, SparseSmoothness) and is_rotated:
515+
fun.gradient_type = "components"
516+
backward_fun = deepcopy(fun)
517+
setattr(backward_fun, "_regularization_mesh", backward_mesh)
518+
519+
# Only do it once for MVI
520+
if not forward_mesh:
521+
backward_fun = set_rotated_operators(
522+
backward_fun,
523+
neighbors,
524+
comp,
525+
self.models.gradient_dip,
526+
self.models.gradient_direction,
527+
forward=False,
528+
)
529+
functions.append(backward_fun)
530+
531+
# Will avoid recomputing operators if the regularization mesh is the same
532+
forward_mesh = reg_func.regularization_mesh
533+
reg_func.objfcts = functions
534+
reg_func.norms = [fun.norm for fun in functions]
535+
reg_funcs.append(reg_func)
536+
537+
# TODO - To be deprcated on GEOPY-2109
538+
if getattr(self.params, "gradient_type") is not None:
539+
for reg in reg_funcs:
540+
setattr(
541+
reg,
542+
"gradient_type",
543+
getattr(self.params, "gradient_type"),
544+
)
530545

531546
return objective_function.ComboObjectiveFunction(objfcts=reg_funcs)
532547

0 commit comments

Comments
 (0)