Skip to content

Commit 155e6f0

Browse files
authored
Merge pull request #373 from MiraGeoscience/GEOPY-2801
GEOPY-2801: PGI: Add checks for reference model present on inversion groups. Return clean GeoAppsError if not
2 parents b3e299d + 7ce2325 commit 155e6f0

4 files changed

Lines changed: 61 additions & 18 deletions

File tree

simpeg_drivers/joint/joint_cross_gradient/driver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
from __future__ import annotations
1515

16+
import sys
1617
from itertools import combinations
18+
from pathlib import Path
1719

1820
from geoh5py.shared.utils import fetch_active_workspace
1921
from simpeg import maps
@@ -71,3 +73,8 @@ def get_regularization(self):
7173
)
7274

7375
return ComboObjectiveFunction(objfcts=reg_list, multipliers=multipliers)
76+
77+
78+
if __name__ == "__main__":
79+
file = Path(sys.argv[1]).resolve()
80+
JointCrossGradientDriver.start_dask_run(file)

simpeg_drivers/joint/joint_petrophysics/driver.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010

1111
from __future__ import annotations
1212

13+
import sys
14+
from pathlib import Path
15+
1316
import numpy as np
17+
from geoapps_utils.utils.importing import GeoAppsError
1418
from geoh5py.shared.utils import fetch_active_workspace
1519
from simpeg import directives, maps, utils
1620
from simpeg.objective_function import ComboObjectiveFunction
@@ -158,6 +162,12 @@ def means(self) -> np.ndarray:
158162
"""
159163
means = []
160164
for mapping in self.mapping:
165+
if self.models.reference_model is None:
166+
raise GeoAppsError(
167+
"A reference model must be set and active on each inversion driver "
168+
"to determine the means of the Gaussian mixture model.\n"
169+
"Please revise the input options of individual drivers."
170+
)
161171
model_vec = mapping @ self.models.reference_model
162172
unit_mean = []
163173
for uid in self.geo_units:
@@ -220,3 +230,8 @@ def _overload_regularization(self, regularization: ComboObjectiveFunction):
220230
reg.alpha_s = 0.0
221231

222232
return reg_list, multipliers
233+
234+
235+
if __name__ == "__main__":
236+
file = Path(sys.argv[1]).resolve()
237+
JointPetrophysicsDriver.start_dask_run(file)

simpeg_drivers/joint/joint_surveys/driver.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
from __future__ import annotations
1313

14+
import sys
1415
from logging import getLogger
16+
from pathlib import Path
1517

1618
import numpy as np
1719
from geoh5py.shared.utils import fetch_active_workspace
@@ -144,3 +146,7 @@ def directives(self):
144146

145147
JointSurveysDriver.n_values = InversionDriver.n_values
146148
JointSurveysDriver.mapping = InversionDriver.mapping
149+
150+
if __name__ == "__main__":
151+
file = Path(sys.argv[1]).resolve()
152+
JointSurveysDriver.start_dask_run(file)

tests/run_tests/driver_joint_pgi_homogeneous_test.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from pathlib import Path
1414

1515
import numpy as np
16+
import pytest
17+
from geoapps_utils import GeoAppsError
1618
from geoh5py.groups.property_group import GroupTypeEnum, PropertyGroup
1719
from geoh5py.objects import Octree, Points
1820
from geoh5py.workspace import Workspace
@@ -145,10 +147,10 @@ def test_homogeneous_fwr_run(
145147
def test_homogeneous_run(
146148
tmp_path: Path,
147149
max_iterations=1,
148-
pytest=True,
150+
use_pytest=True,
149151
):
150152
workpath = tmp_path / "inversion_test.ui.geoh5"
151-
if pytest:
153+
if use_pytest:
152154
workpath = (
153155
tmp_path.parent / "test_homogeneous_fwr_run0" / "inversion_test.ui.geoh5"
154156
)
@@ -228,27 +230,28 @@ def test_homogeneous_run(
228230
inducing_field_declination=INDUCING_FIELD[2],
229231
data_object=survey,
230232
starting_model=ref_model,
231-
reference_model=ref_model,
233+
reference_model=None,
232234
tile_spatial=1,
233235
tmi_channel=data,
234236
tmi_uncertainty=5e0,
235237
)
236238
drivers.append(MagneticInversionDriver(params))
237239

238-
# Test if single group is valid
239-
params = JointPetrophysicsOptions.build(
240-
topography_object=topography,
241-
geoh5=geoh5,
242-
group_a=drivers[0].out_group,
243-
mesh=global_mesh,
244-
petrophysical_model=petrophysics,
245-
)
246-
driver = JointPetrophysicsDriver(params)
247-
assert len(driver.data_misfit.objfcts) == 1
248-
assert driver.data_misfit.multipliers == [1.0]
240+
if len(drivers) == 1:
241+
# Test if single group is valid
242+
params = JointPetrophysicsOptions.build(
243+
topography_object=topography,
244+
geoh5=geoh5,
245+
group_a=drivers[0].out_group,
246+
mesh=global_mesh,
247+
petrophysical_model=petrophysics,
248+
)
249+
driver = JointPetrophysicsDriver(params)
250+
assert len(driver.data_misfit.objfcts) == 1
251+
assert driver.data_misfit.multipliers == [1.0]
249252

250253
# Re-build full
251-
params = JointPetrophysicsOptions.build(
254+
joint_params = JointPetrophysicsOptions.build(
252255
topography_object=topography,
253256
geoh5=geoh5,
254257
group_a=drivers[0].out_group,
@@ -264,10 +267,22 @@ def test_homogeneous_run(
264267
initial_beta_ratio=1e2,
265268
max_global_iterations=max_iterations,
266269
)
267-
driver = JointPetrophysicsDriver(params)
270+
driver = JointPetrophysicsDriver(joint_params)
271+
272+
with pytest.raises(
273+
GeoAppsError, match="A reference model must be set and active on each"
274+
):
275+
_ = driver.means
276+
277+
# Re-instate
278+
params.models.reference_model = ref_model
279+
params.out_group = None
280+
new_driver = MagneticInversionDriver(params)
281+
joint_params.group_b = new_driver.out_group
282+
driver = JointPetrophysicsDriver(joint_params)
268283
driver.run()
269284

270-
if pytest:
285+
if use_pytest:
271286
with Workspace(driver.params.geoh5.h5file) as run_ws:
272287
output = get_inversion_output(
273288
driver.params.geoh5.h5file, driver.out_group.uid
@@ -293,5 +308,5 @@ def test_homogeneous_run(
293308
test_homogeneous_run(
294309
Path("./"),
295310
max_iterations=20,
296-
pytest=False,
311+
use_pytest=False,
297312
)

0 commit comments

Comments
 (0)