Skip to content

Commit 07123ac

Browse files
authored
Merge pull request #319 from MiraGeoscience/GEOPY-2620
GEOPY-2620: Joint surveys using mvi crashes with dimension mismatch
2 parents 069fdb4 + 7b2a06a commit 07123ac

5 files changed

Lines changed: 113 additions & 26 deletions

File tree

simpeg_drivers/joint/driver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def initialize(self):
130130
self.models.active_cells = global_actives
131131
for driver, wire in zip(self.drivers, self.wires, strict=True):
132132
logger.info("Initializing driver %s", driver.params.name)
133+
# Create a projection from global mesh to driver specific mesh
133134
projection = TileMap(
134135
self.inversion_mesh.mesh,
135136
global_actives,
@@ -140,6 +141,8 @@ def initialize(self):
140141
tile_map = projection * wire
141142
driver.params.active_model = None
142143
driver.models.active_cells = projection.local_active
144+
145+
# Keep a copy on the top combo/future for saving directives and model creation
143146
driver.data_misfit.model_map = tile_map
144147

145148
multipliers = []

simpeg_drivers/joint/joint_surveys/driver.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,25 @@ def __init__(self, params: JointSurveysOptions):
4040

4141
def validate_create_models(self):
4242
"""Check if all models were provided, otherwise use the first driver models."""
43+
# Create projection for first driver to global mesh
44+
mapping = maps.TileMap(
45+
self.inversion_mesh.mesh,
46+
self.models.active_cells,
47+
self.drivers[0].inversion_mesh.mesh,
48+
enforce_active=False,
49+
)
50+
projection = mapping.deriv(np.ones(self.models.n_active)).T
51+
norm = np.array(np.sum(projection, axis=1)).flatten()
52+
4353
for model_type in self.models.model_types:
4454
model = getattr(self.models, model_type)
4555
if model is not None or getattr(self.drivers[0].models, model_type) is None:
4656
continue
4757

4858
model_local_values = getattr(self.drivers[0].models, model_type)
49-
projection = (
50-
self.drivers[0]
51-
.data_misfit.model_map.deriv(np.ones(self.models.n_active))
52-
.T
53-
)
54-
norm = np.array(np.sum(projection, axis=1)).flatten()
55-
model = (projection * model_local_values) / (norm + 1e-8)
59+
model = (
60+
projection * model_local_values[: self.drivers[0].models.n_active]
61+
) / (norm + 1e-8)
5662

5763
if self.drivers[0].models.is_sigma and model_type in [
5864
"starting_model",
@@ -70,11 +76,17 @@ def validate_create_models(self):
7076

7177
getattr(self.models, f"_{model_type}").model = model
7278

79+
# For MVI, set is_vector from first driver
80+
self.models.is_vector = self.drivers[0].models.is_vector
81+
7382
@property
7483
def wires(self):
7584
"""Model projections"""
7685
if self._wires is None:
77-
wires = [maps.IdentityMap(nP=self.models.n_active) for _ in self.drivers]
86+
wires = [
87+
maps.IdentityMap(nP=self.models.n_active * driver.n_blocks)
88+
for driver in self.drivers
89+
]
7890
self._wires = wires
7991

8092
return self._wires

simpeg_drivers/joint/joint_surveys/options.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,6 @@ class JointSurveysOptions(BaseJointOptions):
4949

5050
models: JointSurveysModelOptions
5151

52-
@field_validator("group_a", "group_b", "group_c")
53-
@classmethod
54-
def no_mvi_groups(cls, val):
55-
if val is None:
56-
return val
57-
58-
if "magnetic vector" in val.options.get("inversion_type", ""):
59-
raise ValueError(
60-
f"Joint inversion doesn't currently support MVI data as passed in "
61-
f"the group: {val.name}."
62-
)
63-
return val
64-
6552
@model_validator(mode="after")
6653
def all_groups_same_physical_property(self):
6754
physical_properties = [k.options["physical_property"] for k in self.groups]

simpeg_drivers/utils/synthetics/driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import numpy as np
1212
from geoh5py import Workspace
13-
from geoh5py.data import FloatData
13+
from geoh5py.data import BooleanData, FloatData
1414
from geoh5py.objects import DrapeModel, ObjectBase, Octree, Surface
1515

1616
from simpeg_drivers.utils.synthetics.meshes.factory import get_mesh
@@ -86,16 +86,16 @@ def mesh(self):
8686

8787
@property
8888
def active(self):
89-
entity = self.geoh5.get_entity(self.options.active.name)[0]
90-
assert isinstance(entity, FloatData | type(None))
89+
entity = self.mesh.get_entity(self.options.active.name)[0]
90+
assert isinstance(entity, BooleanData | type(None))
9191
if entity is None:
9292
entity = get_active(self.mesh, self.topography)
9393
self._active = entity
9494
return self._active
9595

9696
@property
9797
def model(self):
98-
entity = self.geoh5.get_entity(self.options.model.name)[0]
98+
entity = self.mesh.get_entity(self.options.model.name)[0]
9999
assert isinstance(entity, FloatData | type(None))
100100
if entity is None:
101101
assert self.options is not None

tests/run_tests/driver_joint_surveys_test.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414
from geoh5py.objects import Octree
1515
from geoh5py.workspace import Workspace
16-
from simpeg.directives import SavePropertyGroup
16+
from simpeg.directives import SaveModelGeoH5, SavePropertyGroup
1717

1818
from simpeg_drivers.electricals.direct_current.three_dimensions.driver import (
1919
DC3DInversionDriver,
@@ -31,6 +31,8 @@
3131
GravityInversionOptions,
3232
)
3333
from simpeg_drivers.potential_fields.gravity.driver import GravityInversionDriver
34+
from simpeg_drivers.potential_fields.magnetic_vector.driver import MVIInversionDriver
35+
from simpeg_drivers.potential_fields.magnetic_vector.options import MVIInversionOptions
3436
from simpeg_drivers.utils.synthetics.driver import (
3537
SyntheticsComponents,
3638
)
@@ -195,6 +197,89 @@ def test_joint_surveys_inv_run(
195197
check_target(output, target_run)
196198

197199

200+
def test_joint_surveys_mvi_run(tmp_path, anomaly=0.05):
201+
drivers = []
202+
203+
with Workspace.create(tmp_path / f"{__name__}.geoh5") as geoh5:
204+
for ii in range(1, 3):
205+
opts = SyntheticsComponentsOptions(
206+
method="magnetic_vector",
207+
survey=SurveyOptions(
208+
n_stations=3**ii,
209+
n_lines=3**ii,
210+
drape=5.0,
211+
name=f"Survey Driver[{ii}]",
212+
),
213+
mesh=MeshOptions(refinement=(2**ii, 2, 2), name=f"Mesh Driver[{ii}]"),
214+
model=ModelOptions(anomaly=anomaly),
215+
)
216+
components = SyntheticsComponents(geoh5, options=opts)
217+
survey = components.survey
218+
obs, uncrt = survey.add_data(
219+
{
220+
"TMI": {"values": np.random.randn(survey.n_vertices)},
221+
"Uncertainty": {"values": np.ones(survey.n_vertices) * 1e-3},
222+
}
223+
)
224+
225+
# Add an inclination model on the first driver only to test handling of
226+
# models from the main driver
227+
if ii == 1:
228+
model = components.model.values
229+
model[model > 0] = 45.0
230+
model[model <= 0] = 90.0
231+
inc_mod = components.mesh.add_data(
232+
{"Inclination Model": {"values": model}}
233+
)
234+
else:
235+
inc_mod = None
236+
237+
params = MVIInversionOptions.build(
238+
geoh5=geoh5,
239+
mesh=components.mesh,
240+
topography_object=components.topography,
241+
tmi_channel=obs,
242+
tmi_uncertainty=uncrt,
243+
inducing_field_strength=45000,
244+
inducing_field_inclination=90.0,
245+
inducing_field_declination=0.0,
246+
data_object=survey,
247+
starting_model=components.model,
248+
starting_inclination=inc_mod,
249+
reference_model=0.0,
250+
)
251+
drivers.append(MVIInversionDriver(params))
252+
253+
# Run the inverse
254+
joint_params = JointSurveysOptions.build(
255+
geoh5=geoh5,
256+
active_cells=ActiveCellsOptions(topography_object=components.topography),
257+
group_a=drivers[0].out_group,
258+
group_b=drivers[1].out_group,
259+
starting_model=0.01,
260+
# Default to Conductivity (S/m)
261+
)
262+
263+
driver = JointSurveyDriver(joint_params)
264+
assert np.isclose(driver.models.reference_model[0], 0) # Took it from driver_A
265+
assert driver.models.starting_model.shape == (driver.models.n_active * 3,)
266+
assert np.isclose(
267+
driver.models.starting_model.max(), 0.01 * np.cos(np.deg2rad(45.0))
268+
)
269+
270+
# Test saving the starting models on each mesh (open file to validate)
271+
assert (
272+
len(
273+
[
274+
directive.write(0, driver.models.starting_model)
275+
for directive in driver.directives.directive_list
276+
if isinstance(directive, SaveModelGeoH5)
277+
]
278+
)
279+
== 3
280+
)
281+
282+
198283
def test_joint_surveys_conductivity_run(
199284
tmp_path,
200285
):

0 commit comments

Comments
 (0)