Skip to content

Commit bc24eaf

Browse files
committed
Remove duplicate octree_extent function. Update extent for ndarray
1 parent 898927a commit bc24eaf

3 files changed

Lines changed: 9 additions & 63 deletions

File tree

simpeg_drivers/components/topography.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
get_containing_cells,
4040
get_neighbouring_cells,
4141
mask_vertices_and_cells,
42-
octree_extents,
4342
)
4443

4544

@@ -113,9 +112,8 @@ def active_cells(self, mesh: InversionMesh, data: InversionData) -> np.ndarray:
113112
self.params.active_cells.topography_object, "cells", None
114113
)
115114
else:
116-
extent = octree_extents(mesh.entity)[:4]
117115
vertices, cells = mask_vertices_and_cells(
118-
extent.ravel(order="F"),
116+
mesh.entity.extent,
119117
self.locations,
120118
getattr(self.params.active_cells.topography_object, "cells", None),
121119
)

simpeg_drivers/utils/utils.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from __future__ import annotations
1313

1414
import multiprocessing
15-
from collections.abc import Sequence
1615
from copy import deepcopy
1716
from pathlib import Path
1817
from typing import TYPE_CHECKING
@@ -34,7 +33,7 @@
3433
)
3534
from geoh5py.objects.surveys.electromagnetics.base import LargeLoopGroundEMSurvey
3635
from geoh5py.shared import INTEGER_NDV
37-
from geoh5py.shared.utils import fetch_active_workspace, stringify
36+
from geoh5py.shared.utils import fetch_active_workspace, mask_by_extent, stringify
3837
from geoh5py.ui_json import InputFile
3938
from grid_apps.utils import octree_2_treemesh
4039
from scipy.interpolate import interp1d
@@ -48,46 +47,20 @@
4847
from simpeg_drivers.driver import InversionDriver
4948

5049

51-
def octree_extents(octree: Octree) -> np.ndarray:
52-
"""
53-
Get the true extents of an octree (min/max of the perimeter).
54-
55-
The octree.extents property returns min/max of the centroids
56-
57-
:param octree: Octree mesh object.
58-
59-
:returns: Array of [xmin, xmax, ymin, ymax].
60-
"""
61-
62-
origin = np.array(list(octree.origin.tolist()))
63-
span = np.array(
64-
[
65-
getattr(octree, f"{axis}_cell_size") * getattr(octree, f"{axis}_count")
66-
for axis in "uvw"
67-
]
68-
)
69-
70-
return np.stack([origin, origin + span]).flatten(order="F")
71-
72-
7350
def mask_vertices_and_cells(
74-
extent: Sequence, vertices: np.ndarray, cells: np.ndarray | None
51+
extent: np.ndarray, vertices: np.ndarray, cells: np.ndarray | None
7552
) -> tuple[np.ndarray, np.ndarray]:
7653
"""
7754
Mask vertices and remove cells whose vertices are all outside the extent.
7855
79-
:param extent: Array-like object of [xmin, xmax, ymin, ymax].
56+
:param extent: Array-like object of [[xmin, ymin], [xmax, ymax]].
8057
:param vertices: Array of shape (n_vertices, 3) containing the x, y, z coordinates.
8158
:param cells: Array of shape (n_cells, 3) containing the indices of the vertices
8259
that make up each cell.
8360
"""
8461

85-
vertex_mask = (
86-
(vertices[:, 0] >= extent[0])
87-
& (vertices[:, 0] <= extent[1])
88-
& (vertices[:, 1] >= extent[2])
89-
& (vertices[:, 1] <= extent[3])
90-
)
62+
vertex_mask = mask_by_extent(vertices, extent=extent)
63+
9164
if cells is None:
9265
return vertices[vertex_mask], None
9366

tests/utils_test.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,8 @@
99
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
1010

1111
import numpy as np
12-
from geoh5py import Workspace
13-
from geoh5py.objects import Octree, Points
14-
from grid_apps.octree_creation.driver import OctreeDriver
15-
from grid_apps.octree_creation.options import OctreeOptions, RefinementOptions
1612

17-
from simpeg_drivers.utils.utils import mask_vertices_and_cells, octree_extents
18-
19-
20-
def test_octree_extents(tmp_path):
21-
with Workspace(tmp_path / "test.geoh5") as ws:
22-
X, Y = np.meshgrid(np.linspace(0, 1000, 51), np.linspace(0, 1000, 51))
23-
Z = np.zeros_like(X)
24-
vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
25-
pts = Points.create(ws, name="points", vertices=vertices)
26-
options = OctreeOptions(
27-
geoh5=ws,
28-
objects=pts,
29-
refinements=[
30-
RefinementOptions(
31-
refinement_object=pts, levels=[4, 2], horizon=False, distance=100
32-
),
33-
],
34-
)
35-
octree = OctreeDriver.octree_from_params(options)
36-
37-
extents = octree_extents(octree)
38-
assert np.allclose(extents, [-1112.5, 2087.5, -1112.5, 2087.5, -1062.5, 537.5])
13+
from simpeg_drivers.utils.utils import mask_vertices_and_cells
3914

4015

4116
def test_mask_vertices_and_cells():
@@ -54,11 +29,11 @@ def test_mask_vertices_and_cells():
5429
[7, 5, 8],
5530
]
5631
)
57-
extent = [0.5, 2, 0, 2, 0, 1]
32+
extent = np.vstack([[0.5, 0, 0], [2, 2, 1]])
5833
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
5934
assert len(masked_vertices) == len(vertices)
6035
assert len(masked_cells) == len(cells)
61-
extent = [1.5, 2, 0, 2, 0, 1]
36+
extent = np.vstack([[1.5, 0, 0], [2, 2, 1]])
6237
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
6338
assert len(masked_vertices) == 6
6439
assert len(masked_cells) == 4

0 commit comments

Comments
 (0)