Skip to content

Commit 61f3311

Browse files
committed
tests passing
1 parent a689684 commit 61f3311

2 files changed

Lines changed: 75 additions & 15 deletions

File tree

simpeg_drivers/components/topography.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
floating_active,
3939
get_containing_cells,
4040
get_neighbouring_cells,
41+
mask_vertices_and_cells,
42+
octree_extents,
4143
)
4244

4345

@@ -100,25 +102,26 @@ def active_cells(self, mesh: InversionMesh, data: InversionData) -> np.ndarray:
100102
active_cells = InversionModel.obj_2_mesh(
101103
self.params.active_cells.active_model, mesh.entity
102104
)
105+
103106
else:
104-
topography = self.params.active_cells.topography_object
105-
locations = getattr(topography, "centroids", None) or self.locations
106-
xmin, xmax, ymin, ymax, zmin, zmax = mesh.entity.extent.ravel(order="F")
107-
mask = (
108-
(locations[:, 0] > xmin)
109-
& (locations[:, 0] < xmax)
110-
& (locations[:, 1] > ymin)
111-
& (locations[:, 1] < ymax)
112-
& (locations[:, 2] > zmin)
113-
& (locations[:, 2] < zmax)
114-
)
107+
if any(k in self.params.inversion_type for k in ["2d", "p3d"]):
108+
vertices = self.locations
109+
cells = getattr(
110+
self.params.active_cells.topography_object, "cells", None
111+
)
112+
else:
113+
extent = octree_extents(mesh.entity)
114+
vertices, cells = mask_vertices_and_cells(
115+
extent.ravel(order="F"),
116+
self.locations,
117+
getattr(self.params.active_cells.topography_object, "cells", None),
118+
)
115119

116-
cells = getattr(topography, "cells", None)
117120
active_cells = active_from_xyz(
118121
mesh.entity,
119-
locations,
122+
vertices,
120123
grid_reference="bottom" if forced_to_surface else "center",
121-
triangulation=cells[mask, :] or None,
124+
triangulation=cells,
122125
)
123126

124127
active_cells = (mesh.permutation @ active_cells).astype(bool)

simpeg_drivers/utils/utils.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
from collections.abc import Sequence
1415
from copy import deepcopy
1516
from typing import TYPE_CHECKING
1617

@@ -43,6 +44,62 @@
4344
from simpeg_drivers.driver import InversionDriver
4445

4546

47+
def octree_extents(octree: Octree) -> np.ndarray:
48+
"""
49+
Get the true extents of an octree (min/max of the perimeter).
50+
51+
The octree.extents property returns min/max of the centroids
52+
53+
:param octree: Octree mesh object.
54+
55+
:returns: Array of [xmin, xmax, ymin, ymax].
56+
"""
57+
58+
def half_cell(axis):
59+
return (
60+
getattr(octree, f"{axis}_cell_size") * octree.octree_cells["NCells"]
61+
) / 2
62+
63+
xmin = (octree.centroids[:, 0] - half_cell("u")).min()
64+
xmax = (octree.centroids[:, 0] + half_cell("u")).max()
65+
ymin = (octree.centroids[:, 1] - half_cell("v")).min()
66+
ymax = (octree.centroids[:, 1] + half_cell("v")).max()
67+
68+
return np.array([xmin, xmax, ymin, ymax])
69+
70+
71+
def mask_vertices_and_cells(
72+
extent: Sequence, vertices: np.ndarray, cells: np.ndarray | None
73+
) -> tuple[np.ndarray, np.ndarray]:
74+
"""
75+
Mask vertices and remove cells whose vertices are all outside the extent.
76+
77+
:param extent: Array-like object of [xmin, xmax, ymin, ymax].
78+
:param vertices: Array of shape (n_vertices, 3) containing the x, y, z coordinates.
79+
:param cells: Array of shape (n_cells, 3) containing the indices of the vertices
80+
that make up each cell.
81+
"""
82+
83+
vertex_mask = (
84+
(vertices[:, 0] >= extent[0])
85+
& (vertices[:, 0] <= extent[1])
86+
& (vertices[:, 1] >= extent[2])
87+
& (vertices[:, 1] <= extent[3])
88+
)
89+
if cells is None:
90+
return vertices[vertex_mask], None
91+
92+
cell_mask = np.any(vertex_mask[cells], axis=1)
93+
vertex_mask = np.zeros_like(vertex_mask, dtype=bool)
94+
vertex_mask[cells[cell_mask].flatten()] = True
95+
96+
new_cells = cells.copy()[cell_mask]
97+
cell_map = np.arange(len(vertices))[vertex_mask]
98+
new_cells = np.searchsorted(cell_map, new_cells)
99+
100+
return vertices[vertex_mask], new_cells
101+
102+
46103
def calculate_2D_trend(
47104
points: np.ndarray, values: np.ndarray, order: int = 0, method: str = "all"
48105
):
@@ -495,7 +552,7 @@ def active_from_xyz(
495552
raise ValueError("'grid_reference' must be one of 'center', 'top', or 'bottom'")
496553

497554
# Return the active cell array
498-
return mask_under_horizon(locations, topo, triangulation=triangulation)
555+
return mask_under_horizon(locations, horizon=topo, triangulation=triangulation)
499556

500557

501558
def truncate_locs_depths(locs: np.ndarray, depth_core: float) -> np.ndarray:

0 commit comments

Comments
 (0)