Skip to content

Commit aa26839

Browse files
committed
Change mechanics for mapping models in joint surveys. Add unit test
1 parent 253f2fd commit aa26839

4 files changed

Lines changed: 107 additions & 10 deletions

File tree

simpeg_drivers/joint/driver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def initialize(self):
129129
self.models.active_cells = global_actives
130130
for driver, wire in zip(self.drivers, self.wires, strict=True):
131131
logger.info("Initializing driver %s", driver.params.name)
132+
# Create a projection from global mesh to driver specific mesh
132133
projection = TileMap(
133134
self.inversion_mesh.mesh,
134135
global_actives,
@@ -139,6 +140,8 @@ def initialize(self):
139140
tile_map = projection * wire
140141
driver.params.active_model = None
141142
driver.models.active_cells = projection.local_active
143+
144+
# Keep a copy on the top combo/future for saving directives and model creation
142145
driver.data_misfit.model_map = tile_map
143146

144147
multipliers = []

simpeg_drivers/joint/joint_surveys/driver.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,25 @@ def __init__(self, params: JointSurveysOptions):
4040

4141
def validate_create_models(self):
4242
"""Check if all models were provided, otherwise use the first driver models."""
43+
# Create projection for first driver to global mesh
44+
mapping = maps.TileMap(
45+
self.inversion_mesh.mesh,
46+
self.models.active_cells,
47+
self.drivers[0].inversion_mesh.mesh,
48+
enforce_active=False,
49+
)
50+
projection = mapping.deriv(np.ones(self.models.n_active)).T
51+
norm = np.array(np.sum(projection, axis=1)).flatten()
52+
4353
for model_type in self.models.model_types:
4454
model = getattr(self.models, model_type)
4555
if model is not None or getattr(self.drivers[0].models, model_type) is None:
4656
continue
4757

4858
model_local_values = getattr(self.drivers[0].models, model_type)
49-
projection = (
50-
self.drivers[0]
51-
.data_misfit.model_map.deriv(np.ones(self.models.n_active))
52-
.T
53-
)
54-
norm = np.array(np.sum(projection, axis=1)).flatten()
55-
model = (projection * model_local_values) / (norm + 1e-8)
59+
model = (
60+
projection * model_local_values[: self.drivers[0].models.n_active]
61+
) / (norm + 1e-8)
5662

5763
if self.drivers[0].models.is_sigma and model_type in [
5864
"starting_model",
@@ -70,6 +76,9 @@ def validate_create_models(self):
7076

7177
getattr(self.models, f"_{model_type}").model = model
7278

79+
# For MVI, set is_vector from first driver
80+
self.models.is_vector = self.drivers[0].models.is_vector
81+
7382
@property
7483
def wires(self):
7584
"""Model projections"""

simpeg_drivers/utils/synthetics/driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def mesh(self):
8686

8787
@property
8888
def active(self):
89-
entity = self.geoh5.get_entity(self.options.active.name)[0]
89+
entity = self.mesh.get_entity(self.options.active.name)[0]
9090
assert isinstance(entity, FloatData | type(None))
9191
if entity is None:
9292
entity = get_active(self.mesh, self.topography)
@@ -95,7 +95,7 @@ def active(self):
9595

9696
@property
9797
def model(self):
98-
entity = self.geoh5.get_entity(self.options.model.name)[0]
98+
entity = self.mesh.get_entity(self.options.model.name)[0]
9999
assert isinstance(entity, FloatData | type(None))
100100
if entity is None:
101101
assert self.options is not None

tests/run_tests/driver_joint_surveys_test.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414
from geoh5py.objects import Octree
1515
from geoh5py.workspace import Workspace
16-
from simpeg.directives import SavePropertyGroup
16+
from simpeg.directives import SaveModelGeoH5, SavePropertyGroup
1717

1818
from simpeg_drivers.electricals.direct_current.three_dimensions.driver import (
1919
DC3DInversionDriver,
@@ -31,6 +31,8 @@
3131
GravityInversionOptions,
3232
)
3333
from simpeg_drivers.potential_fields.gravity.driver import GravityInversionDriver
34+
from simpeg_drivers.potential_fields.magnetic_vector.driver import MVIInversionDriver
35+
from simpeg_drivers.potential_fields.magnetic_vector.options import MVIInversionOptions
3436
from simpeg_drivers.utils.synthetics.driver import (
3537
SyntheticsComponents,
3638
)
@@ -195,6 +197,89 @@ def test_joint_surveys_inv_run(
195197
check_target(output, target_run)
196198

197199

200+
def test_joint_surveys_mvi_run(tmp_path, anomaly=0.05):
201+
drivers = []
202+
203+
with Workspace.create(tmp_path / f"{__name__}.geoh5") as geoh5:
204+
for ii in range(1, 3):
205+
opts = SyntheticsComponentsOptions(
206+
method="magnetic_vector",
207+
survey=SurveyOptions(
208+
n_stations=3**ii,
209+
n_lines=3**ii,
210+
drape=5.0,
211+
name=f"Survey Driver[{ii}]",
212+
),
213+
mesh=MeshOptions(refinement=(2**ii, 2, 2), name=f"Mesh Driver[{ii}]"),
214+
model=ModelOptions(anomaly=anomaly),
215+
)
216+
components = SyntheticsComponents(geoh5, options=opts)
217+
survey = components.survey
218+
obs, uncrt = survey.add_data(
219+
{
220+
"TMI": {"values": np.random.randn(survey.n_vertices)},
221+
"Uncertainty": {"values": np.ones(survey.n_vertices) * 1e-3},
222+
}
223+
)
224+
225+
# Add an inclination model on the first driver only to test handling of
226+
# models from the main driver
227+
if ii == 1:
228+
model = components.model.values
229+
model[model > 0] = 45.0
230+
model[model <= 0] = 90.0
231+
inc_mod = components.mesh.add_data(
232+
{"Inclination Model": {"values": model}}
233+
)
234+
else:
235+
inc_mod = None
236+
237+
params = MVIInversionOptions.build(
238+
geoh5=geoh5,
239+
mesh=components.mesh,
240+
topography_object=components.topography,
241+
tmi_channel=obs,
242+
tmi_uncertainty=uncrt,
243+
inducing_field_strength=45000,
244+
inducing_field_inclination=90.0,
245+
inducing_field_declination=0.0,
246+
data_object=survey,
247+
starting_model=components.model,
248+
starting_inclination=inc_mod,
249+
reference_model=0.0,
250+
)
251+
drivers.append(MVIInversionDriver(params))
252+
253+
# Run the inverse
254+
joint_params = JointSurveysOptions.build(
255+
geoh5=geoh5,
256+
active_cells=ActiveCellsOptions(topography_object=components.topography),
257+
group_a=drivers[0].out_group,
258+
group_b=drivers[1].out_group,
259+
starting_model=0.01,
260+
# Default to Conductivity (S/m)
261+
)
262+
263+
driver = JointSurveyDriver(joint_params)
264+
assert np.isclose(driver.models.reference_model[0], 0) # Took it from driver_A
265+
assert driver.models.starting_model.shape == (driver.models.n_active * 3,)
266+
assert np.isclose(
267+
driver.models.starting_model.max(), 0.01 * np.cos(np.deg2rad(45.0))
268+
)
269+
270+
# Test saving the starting models on each mesh (open file to validate)
271+
assert (
272+
len(
273+
[
274+
directive.write(0, driver.models.starting_model)
275+
for directive in driver.directives.directive_list
276+
if isinstance(directive, SaveModelGeoH5)
277+
]
278+
)
279+
== 3
280+
)
281+
282+
198283
def test_joint_surveys_conductivity_run(
199284
tmp_path,
200285
):

0 commit comments

Comments
 (0)