Skip to content

Commit 146ef54

Browse files
committed
Use params for conversion of line ids to parts.
1 parent 8ae9c50 commit 146ef54

7 files changed

Lines changed: 57 additions & 62 deletions

File tree

simpeg_drivers/components/data.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,10 @@ def __init__(self, workspace: Workspace, params: InversionBaseOptions):
9090
def _initialize(self) -> None:
9191
"""Extract data from the workspace using params data."""
9292
self.components = self.params.active_components
93-
9493
self.has_tensor = InversionData.check_tensor(self.params.components)
9594
self.locations = super().get_locations(self.params.data_object)
96-
97-
if (
98-
"2d" in self.params.inversion_type
99-
and self.params.line_selection.line_id is not None
100-
):
101-
self.mask = (
102-
self.params.line_selection.line_object.values
103-
== self.params.line_selection.line_id
104-
)
105-
else:
106-
self.mask = np.ones(len(self.locations), dtype=bool)
107-
95+
self.mask = np.ones(len(self.locations), dtype=bool)
10896
self.normalizations: dict[str, Any] = self.get_normalizations()
109-
11097
self.entity = self.write_entity()
11198

11299
self.save_data()
@@ -256,6 +243,15 @@ def save_data(self):
256243
self._observed_data_types = data_types
257244
self.update_params(data_dict, uncert_dict)
258245

246+
if (
247+
getattr(self.params, "line_selection", None) is not None
248+
and self.params.line_selection.property is not None
249+
):
250+
self.params.line_selection.property.copy(
251+
parent=self.entity,
252+
values=self.params.line_selection.property.values[self.mask],
253+
)
254+
259255
def normalize(
260256
self, data: dict[str, np.ndarray], absolute=False
261257
) -> dict[str, np.ndarray]:
@@ -358,9 +354,8 @@ def create_survey(self):
358354
survey.cells = self.entity.cells
359355

360356
if "2d" in self.params.inversion_type:
361-
survey.line_ids = self.params.line_selection.line_object.values[
362-
survey_factory.sorting
363-
]
357+
# Assign line id with sequential numbering to mirror the drape mesh parts
358+
survey.line_ids = self.params.line_parts[survey_factory.sorting]
364359

365360
return survey
366361

@@ -395,18 +390,6 @@ def update_params(self, data_dict, uncert_dict):
395390
setattr(self.params, f"{comp}_channel", data_dict[comp])
396391
setattr(self.params, f"{comp}_uncertainty", uncert_dict[comp])
397392

398-
if getattr(self.params, "line_selection", None) is not None:
399-
if self.params.line_selection.line_object is None:
400-
parts = get_parts_from_electrodes(self.entity)
401-
line_ids = self.entity.add_data({"Line IDs": {"values": parts + 1}})
402-
else:
403-
line_ids = self.params.line_selection.line_object.copy(
404-
parent=self.entity,
405-
values=self.params.line_selection.line_object.values[self.mask],
406-
)
407-
408-
self.params.line_selection.line_object = line_ids
409-
410393
@property
411394
def survey(self):
412395
if self._survey is None:

simpeg_drivers/components/factories/directives_factory.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import numpy as np
2222
from geoh5py.groups.property_group import GroupTypeEnum
23+
from geoh5py.objects import PotentialElectrode
2324
from numpy import sqrt
2425
from simpeg import directives, maps
2526
from simpeg.utils.mat_utils import cartesian2amplitude_dip_azimuth
@@ -513,10 +514,14 @@ def assemble_keyword_arguments(
513514
]
514515
components = list(inversion_object.observed)
515516
ordering = inversion_object.survey.ordering
516-
n_locations = len(np.unique(ordering[:, 2]))
517+
518+
if isinstance(receivers, PotentialElectrode):
519+
n_locations = receivers.n_cells
520+
else:
521+
n_locations = receivers.n_vertices
517522

518523
def reshape(values):
519-
data = np.zeros((len(channels), len(components), n_locations))
524+
data = np.full((len(channels), len(components), n_locations), np.nan)
520525
data[ordering[:, 0], ordering[:, 1], ordering[:, 2]] = values
521526
return data
522527

simpeg_drivers/options.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,10 @@ class LineSelectionOptions(BaseModel):
504504
)
505505
line_id: int | None = None
506506
line_object: IntegerData | ReferencedData | None = None
507+
property: ReferencedData | None = None
508+
value: list[int] | None = None
507509

508-
@field_validator("line_object", mode="before")
510+
@field_validator("property", mode="before")
509511
@classmethod
510512
def validate_cell_association(cls, value):
511513
if value and value.association is not DataAssociationEnum.CELL:
@@ -514,8 +516,17 @@ def validate_cell_association(cls, value):
514516

515517
@model_validator(mode="after")
516518
def line_id_referenced(self):
517-
if self.line_id is not None and self.line_id not in self.line_object.values:
518-
raise ValueError("Line id isn't referenced in the line object.")
519+
if self.line_object is not None:
520+
logger.warning(
521+
"Running with an older version of DC inversion 2D.\n"
522+
"Please update to version 0.5.0 or later to ensure line selection is properly applied.\n"
523+
"Results may be affected.",
524+
)
525+
self.property = self.line_object
526+
527+
if isinstance(self.line_id, int):
528+
self.value = [self.line_id]
529+
519530
return self
520531

521532

simpeg_drivers/utils/nested.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,16 +308,7 @@ def create_simulation(
308308
# the local active cells
309309
else:
310310
# Map the line_ids to the mesh parts (assumes sequential numbering)
311-
line_number = (
312-
np.where(
313-
np.isin(
314-
np.unique(simulation.survey.line_ids),
315-
np.unique(local_survey.line_ids),
316-
)
317-
)[0]
318-
+ 1
319-
)
320-
311+
line_number = np.unique(local_survey.line_ids)
321312
active_mesh_part = np.isin(simulation.mesh.parts, line_number)
322313
n_actives = simulation.active_cells.sum()
323314
activate_ind = np.zeros(simulation.mesh.n_cells, dtype=int)

simpeg_drivers/utils/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def drape_2_tensor(drape_model: DrapeModel, return_sorting: bool = False) -> tup
306306

307307
# Skip indices for ghost points
308308
count = -1
309-
part = 1
309+
part = 0
310310
parts = []
311311
cell_widths = []
312312
section = []
@@ -406,7 +406,9 @@ def get_drape_model(
406406
min_locs = locations.min(axis=0)
407407
max_locs = locations.max(axis=0)
408408
xyz_smooth -= xyz_smooth.min(axis=0)[None, :]
409-
xyz_smooth *= ((max_locs - min_locs) / xyz_smooth.max(axis=0))[None, :]
409+
xyz_smooth *= ((max_locs - min_locs) / np.maximum(xyz_smooth.max(axis=0), 1e-3))[
410+
None, :
411+
]
410412
xyz_smooth += min_locs[None, :]
411413

412414
distances = compute_alongline_distance(xyz_smooth)

tests/run_tests/driver_dc_2d_test.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,17 @@ def test_dc_2d_fwr_run(
7676
),
7777
),
7878
)
79+
7980
with get_workspace(tmp_path / "inversion_test.ui.geoh5") as geoh5:
8081
components = SyntheticsComponents(geoh5=geoh5, options=opts)
82+
line_ids = components.survey.get_data("line_ids")[0]
8183
params = DC2DForwardOptions.build(
8284
geoh5=geoh5,
8385
mesh=components.mesh,
8486
topography_object=components.topography,
8587
data_object=components.survey,
8688
starting_model=components.model,
87-
line_selection=LineSelectionOptions(),
89+
line_selection=LineSelectionOptions(property=line_ids, value=[1, 101, 201]),
8890
)
8991
fwr_driver = DC2DForwardDriver(params)
9092
fwr_driver.run()
@@ -119,9 +121,6 @@ def test_dc_2d_run(
119121
data_object=potential.parent,
120122
potential_channel=potential,
121123
potential_uncertainty=uncertainties,
122-
line_selection=LineSelectionOptions(
123-
line_object=potential.parent.get_entity("Line IDs")[0]
124-
),
125124
starting_model=1e-3,
126125
reference_model=1e-3,
127126
s_norm=0.0,
@@ -134,10 +133,11 @@ def test_dc_2d_run(
134133
upper_bound=10,
135134
cooling_rate=1,
136135
)
137-
params.write_ui_json(path=tmp_path / "Inv_run.ui.json")
138-
139-
driver = DC2DInversionDriver.start(str(tmp_path / "Inv_run.ui.json"))
136+
# TODO Fix the write for MultiSelect of Reference data
137+
# params.write_ui_json(path=tmp_path / "Inv_run.ui.json")
140138

139+
driver = DC2DInversionDriver(params)
140+
driver.run()
141141
output = get_inversion_output(
142142
driver.params.geoh5.h5file, driver.params.out_group.uid
143143
)
@@ -168,6 +168,9 @@ def test_dc_single_run(
168168
}
169169
}
170170
)
171+
172+
line_ids = survey.get_data("line_ids")[0]
173+
171174
# Run the inverse
172175
params = DC2DInversionOptions.build(
173176
geoh5=geoh5,
@@ -177,9 +180,7 @@ def test_dc_single_run(
177180
data_object=potential.parent,
178181
potential_channel=potential,
179182
potential_uncertainty=uncertainties,
180-
line_selection=LineSelectionOptions(
181-
line_object=potential.parent.get_entity("Line IDs")[0], line_id=2
182-
),
183+
line_selection=LineSelectionOptions(property=line_ids, value=[101]),
183184
starting_model=1e-3,
184185
reference_model=1e-3,
185186
s_norm=0.0,
@@ -194,16 +195,17 @@ def test_dc_single_run(
194195
)
195196
params.write_ui_json(path=tmp_path / "Inv_run.ui.json")
196197

197-
DC2DInversionDriver.start(str(tmp_path / "Inv_run.ui.json"))
198+
driver = DC2DInversionDriver(params)
199+
driver.run()
198200

199201
with Workspace(workpath) as geoh5:
200202
inv_group = geoh5.get_entity("Direct Current Single 2D Inversion")[0]
201203
mesh = inv_group.get_entity("mesh")[0]
202204
model = mesh.get_entity("Iteration_1_model")[0]
203205

204206
# Check that model values for lines 1 and 3 are close to the starting model (1e-3) and that line 2 has been updated.
205-
np.testing.assert_almost_equal(np.nanmin(model.values[:2369]), 1e-3, decimal=3)
206-
np.testing.assert_almost_equal(np.nanmin(model.values[-2368:]), 1e-3, decimal=3)
207+
np.testing.assert_almost_equal(np.nanmax(model.values[:2369]), 1e-3, decimal=3)
208+
np.testing.assert_almost_equal(np.nanmax(model.values[-2368:]), 1e-3, decimal=3)
207209
assert np.nanmax(model.values[2368:-2368]) > 1e-3
208210

209211

tests/run_tests/driver_ip_2d_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ def test_ip_2d_run(
127127
upper_bound=0.1,
128128
cooling_rate=1,
129129
)
130-
params.write_ui_json(path=tmp_path / "Inv_run.ui.json")
131-
132-
driver = IP2DInversionDriver.start(str(tmp_path / "Inv_run.ui.json"))
130+
# TODO Fix the write out with Multiselect of ReferenceData values
131+
# params.write_ui_json(path=tmp_path / "Inv_run.ui.json")
133132

133+
driver = IP2DInversionDriver(params)
134+
driver.run()
134135
output = get_inversion_output(
135136
driver.params.geoh5.h5file, driver.params.out_group.uid
136137
)

0 commit comments

Comments
 (0)