Skip to content

Commit 17ca343

Browse files
authored
Merge pull request #344 from MiraGeoscience/GEOPY-2622
GEOPY-2622: Clip limits of topography extent based on mesh extent
2 parents 392ece3 + 72e44f5 commit 17ca343

3 files changed

Lines changed: 138 additions & 5 deletions

File tree

simpeg_drivers/components/topography.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
floating_active,
3939
get_containing_cells,
4040
get_neighbouring_cells,
41+
mask_vertices_and_cells,
42+
octree_extents,
4143
)
4244

4345

@@ -101,14 +103,26 @@ def active_cells(self, mesh: InversionMesh, data: InversionData) -> np.ndarray:
101103
active_cells = InversionModel.obj_2_mesh(
102104
self.params.active_cells.active_model, mesh.entity
103105
)
106+
104107
else:
108+
if any(k in self.params.inversion_type for k in ["2d", "p3d"]):
109+
vertices = self.locations
110+
cells = getattr(
111+
self.params.active_cells.topography_object, "cells", None
112+
)
113+
else:
114+
extent = octree_extents(mesh.entity)[:4]
115+
vertices, cells = mask_vertices_and_cells(
116+
extent.ravel(order="F"),
117+
self.locations,
118+
getattr(self.params.active_cells.topography_object, "cells", None),
119+
)
120+
105121
active_cells = active_from_xyz(
106122
mesh.entity,
107-
self.locations,
123+
vertices,
108124
grid_reference="bottom" if forced_to_surface else "center",
109-
triangulation=getattr(
110-
self.params.active_cells.topography_object, "cells", None
111-
),
125+
triangulation=cells,
112126
)
113127

114128
active_cells = (mesh.permutation @ active_cells).astype(bool)

simpeg_drivers/utils/utils.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
from collections.abc import Sequence
1415
from copy import deepcopy
1516
from typing import TYPE_CHECKING
1617

@@ -46,6 +47,60 @@
4647
from simpeg_drivers.driver import InversionDriver
4748

4849

50+
def octree_extents(octree: Octree) -> np.ndarray:
51+
"""
52+
Get the true extents of an octree (min/max of the perimeter).
53+
54+
The octree.extents property returns min/max of the centroids
55+
56+
:param octree: Octree mesh object.
57+
58+
:returns: Array of [xmin, xmax, ymin, ymax].
59+
"""
60+
61+
origin = np.array(list(octree.origin.tolist()))
62+
span = np.array(
63+
[
64+
getattr(octree, f"{axis}_cell_size") * getattr(octree, f"{axis}_count")
65+
for axis in "uvw"
66+
]
67+
)
68+
69+
return np.stack([origin, origin + span]).flatten(order="F")
70+
71+
72+
def mask_vertices_and_cells(
73+
extent: Sequence, vertices: np.ndarray, cells: np.ndarray | None
74+
) -> tuple[np.ndarray, np.ndarray]:
75+
"""
76+
Mask vertices and remove cells whose vertices are all outside the extent.
77+
78+
:param extent: Array-like object of [xmin, xmax, ymin, ymax].
79+
:param vertices: Array of shape (n_vertices, 3) containing the x, y, z coordinates.
80+
:param cells: Array of shape (n_cells, 3) containing the indices of the vertices
81+
that make up each cell.
82+
"""
83+
84+
vertex_mask = (
85+
(vertices[:, 0] >= extent[0])
86+
& (vertices[:, 0] <= extent[1])
87+
& (vertices[:, 1] >= extent[2])
88+
& (vertices[:, 1] <= extent[3])
89+
)
90+
if cells is None:
91+
return vertices[vertex_mask], None
92+
93+
cell_mask = np.any(vertex_mask[cells], axis=1)
94+
vertex_mask = np.zeros_like(vertex_mask, dtype=bool)
95+
vertex_mask[cells[cell_mask].flatten()] = True
96+
97+
new_cells = cells.copy()[cell_mask]
98+
cell_map = np.arange(len(vertices))[vertex_mask]
99+
new_cells = np.searchsorted(cell_map, new_cells)
100+
101+
return vertices[vertex_mask], new_cells
102+
103+
49104
def calculate_2D_trend(
50105
points: np.ndarray, values: np.ndarray, order: int = 0, method: str = "all"
51106
):
@@ -500,7 +555,7 @@ def active_from_xyz(
500555
raise ValueError("'grid_reference' must be one of 'center', 'top', or 'bottom'")
501556

502557
# Return the active cell array
503-
return mask_under_horizon(locations, topo, triangulation=triangulation)
558+
return mask_under_horizon(locations, horizon=topo, triangulation=triangulation)
504559

505560

506561
def truncate_locs_depths(locs: np.ndarray, depth_core: float) -> np.ndarray:

tests/utils_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
2+
# Copyright (c) 2026 Mira Geoscience Ltd. '
3+
# '
4+
# This file is part of simpeg-drivers package. '
5+
# '
6+
# simpeg-drivers is distributed under the terms and conditions of the MIT License '
7+
# (see LICENSE file at the root of this source code package). '
8+
# '
9+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
10+
11+
import numpy as np
12+
from geoh5py import Workspace
13+
from geoh5py.objects import Octree, Points
14+
from grid_apps.octree_creation.driver import OctreeDriver
15+
from grid_apps.octree_creation.options import OctreeOptions, RefinementOptions
16+
17+
from simpeg_drivers.utils.utils import mask_vertices_and_cells, octree_extents
18+
19+
20+
def test_octree_extents(tmp_path):
21+
with Workspace(tmp_path / "test.geoh5") as ws:
22+
X, Y = np.meshgrid(np.linspace(0, 1000, 51), np.linspace(0, 1000, 51))
23+
Z = np.zeros_like(X)
24+
vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
25+
pts = Points.create(ws, name="points", vertices=vertices)
26+
options = OctreeOptions(
27+
geoh5=ws,
28+
objects=pts,
29+
refinements=[
30+
RefinementOptions(
31+
refinement_object=pts, levels=[4, 2], horizon=False, distance=100
32+
),
33+
],
34+
)
35+
octree = OctreeDriver.octree_from_params(options)
36+
37+
extents = octree_extents(octree)
38+
assert np.allclose(extents, [-1112.5, 2087.5, -1112.5, 2087.5, -1062.5, 537.5])
39+
40+
41+
def test_mask_vertices_and_cells():
42+
X, Y = np.meshgrid(np.arange(3), np.arange(3))
43+
Z = np.zeros_like(X)
44+
vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
45+
cells = np.array(
46+
[
47+
[0, 1, 3],
48+
[3, 1, 4],
49+
[1, 2, 4],
50+
[4, 2, 5],
51+
[3, 4, 6],
52+
[6, 4, 7],
53+
[4, 5, 7],
54+
[7, 5, 8],
55+
]
56+
)
57+
extent = [0.5, 2, 0, 2, 0, 1]
58+
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
59+
assert len(masked_vertices) == len(vertices)
60+
assert len(masked_cells) == len(cells)
61+
extent = [1.5, 2, 0, 2, 0, 1]
62+
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
63+
assert len(masked_vertices) == 6
64+
assert len(masked_cells) == 4

0 commit comments

Comments
 (0)