Skip to content

Commit c151e56

Browse files
authored
Merge branch 'release/0.3.0' into GEOPY-2193
2 parents dff5b3f + f910d89 commit c151e56

62 files changed

Lines changed: 13420 additions & 215 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ repos:
8686
- id: mixed-line-ending
8787
exclude: ^\.idea/.*\.xml$
8888
- id: name-tests-test
89+
exclude: testing_utils.py
8990
- id: pretty-format-json
9091
args:
9192
- --autofix

simpeg_drivers-assets/uijson/tdem_inversion.ui.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@
315315
],
316316
"label": "Gradient rotation",
317317
"parent": "mesh",
318+
"optional": true,
319+
"enabled": false,
318320
"value": ""
319321
},
320322
"s_norm": {

simpeg_drivers/components/factories/misfit_factory.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,19 @@ def assemble_arguments( # pylint: disable=arguments-differ
7979
for local_index in tiles:
8080
if len(local_index) == 0:
8181
continue
82-
local_mesh = None
82+
83+
local_sim, _, _, _ = self.create_nested_simulation(
84+
inversion_data,
85+
inversion_mesh,
86+
None,
87+
active_cells,
88+
local_index,
89+
channel=None,
90+
tile_id=tile_count,
91+
padding_cells=self.params.padding_cells,
92+
)
93+
94+
local_mesh = getattr(local_sim, "mesh", None)
8395

8496
for count, channel in enumerate(channels):
8597
n_split = split_list[misfit_count]
@@ -97,8 +109,6 @@ def assemble_arguments( # pylint: disable=arguments-differ
97109
)
98110
)
99111

100-
local_mesh = getattr(local_sim, "mesh", None)
101-
102112
if count == 0:
103113
if self.factory_type in [
104114
"fdem",

simpeg_drivers/driver.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -612,30 +612,38 @@ def configure_dask(self):
612612
dconf.set(scheduler="threads", pool=ThreadPool(n_cpu))
613613

614614
@classmethod
615-
def start(cls, filepath: str | Path | InputFile):
615+
def start(
616+
cls, filepath: str | Path | InputFile, driver_class=None, **kwargs
617+
) -> InversionDriver:
618+
"""
619+
Start the inversion driver.
620+
621+
:param filepath: Path to the input file or InputFile object.
622+
:param driver_class: Optional driver class to use instead of the default.
623+
:param kwargs: Additional keyword arguments for InputFile read_ui_json.
624+
625+
:return: InversionDriver instance with the specified parameters.
626+
"""
616627
if isinstance(filepath, InputFile):
617628
ifile = filepath
618629
else:
619-
ifile = InputFile.read_ui_json(filepath)
630+
ifile = InputFile.read_ui_json(filepath, **kwargs)
620631

621-
forward_only = ifile.data["forward_only"]
622-
inversion_type = ifile.ui_json.get("inversion_type", None)
623-
624-
driver_class = cls.driver_class_from_name(
625-
inversion_type, forward_only=forward_only
626-
)
627-
628-
with ifile.data["geoh5"].open(mode="r+"):
629-
params = driver_class._options_class.build(ifile)
630-
driver = driver_class(params)
632+
if driver_class is None:
633+
driver = cls.from_input_file(ifile)
634+
else:
635+
with ifile.data["geoh5"].open(mode="r+"):
636+
params = driver_class._options_class.build(ifile)
637+
driver = driver_class(params)
631638

632639
driver.run()
640+
633641
return driver
634642

635643
@staticmethod
636644
def driver_class_from_name(
637645
name: str, forward_only: bool = False
638-
) -> InversionDriver:
646+
) -> type[InversionDriver]:
639647
if name not in DRIVER_MAP:
640648
msg = f"Inversion type {name} is not supported."
641649
msg += f" Valid inversions are: {(*list(DRIVER_MAP),)}."
@@ -649,6 +657,26 @@ def driver_class_from_name(
649657
module = __import__(mod_name, fromlist=[class_name])
650658
return getattr(module, class_name)
651659

660+
@classmethod
661+
def from_input_file(cls, ifile: InputFile) -> InversionDriver:
662+
forward_only = ifile.data["forward_only"]
663+
inversion_type = ifile.ui_json.get("inversion_type", None)
664+
if inversion_type is None:
665+
raise GeoAppsError(
666+
"Key/value 'inversion_type' not found in the input file. "
667+
"Please specify the inversion type in the UI JSON."
668+
)
669+
670+
driver_class = cls.driver_class_from_name(
671+
inversion_type, forward_only=forward_only
672+
)
673+
674+
with ifile.data["geoh5"].open(mode="r+"):
675+
params = driver_class._options_class.build(ifile)
676+
driver = driver_class(params)
677+
678+
return driver
679+
652680

653681
class InversionLogger:
654682
def __init__(self, logfile, driver):

tests/data_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from simpeg_drivers.potential_fields.magnetic_vector.options import (
3030
MVIInversionOptions,
3131
)
32-
from simpeg_drivers.utils.testing import Geoh5Tester, setup_inversion_workspace
32+
from tests.testing_utils import Geoh5Tester, setup_inversion_workspace
3333

3434

3535
def get_mvi_params(tmp_path: Path, **kwargs) -> MVIInversionOptions:

tests/driver_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from simpeg_drivers.options import ActiveCellsOptions
1616
from simpeg_drivers.potential_fields import GravityInversionOptions
1717
from simpeg_drivers.potential_fields.gravity.driver import GravityInversionDriver
18-
from simpeg_drivers.utils.testing import setup_inversion_workspace
18+
from tests.testing_utils import setup_inversion_workspace
1919

2020

2121
def test_smallness_terms(tmp_path: Path):

0 commit comments

Comments
 (0)