Skip to content

Commit fdf7fd0

Browse files
committed
Cleaner octree_extents implementation
1 parent 61f3311 commit fdf7fd0

3 files changed

Lines changed: 73 additions & 11 deletions

File tree

simpeg_drivers/components/topography.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def active_cells(self, mesh: InversionMesh, data: InversionData) -> np.ndarray:
110110
self.params.active_cells.topography_object, "cells", None
111111
)
112112
else:
113-
extent = octree_extents(mesh.entity)
113+
extent = octree_extents(mesh.entity)[:4]
114114
vertices, cells = mask_vertices_and_cells(
115115
extent.ravel(order="F"),
116116
self.locations,

simpeg_drivers/utils/utils.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,15 @@ def octree_extents(octree: Octree) -> np.ndarray:
5555
:returns: Array of [xmin, xmax, ymin, ymax].
5656
"""
5757

58-
def half_cell(axis):
59-
return (
60-
getattr(octree, f"{axis}_cell_size") * octree.octree_cells["NCells"]
61-
) / 2
62-
63-
xmin = (octree.centroids[:, 0] - half_cell("u")).min()
64-
xmax = (octree.centroids[:, 0] + half_cell("u")).max()
65-
ymin = (octree.centroids[:, 1] - half_cell("v")).min()
66-
ymax = (octree.centroids[:, 1] + half_cell("v")).max()
58+
origin = np.array(list(octree.origin.tolist()))
59+
span = np.array(
60+
[
61+
getattr(octree, f"{axis}_cell_size") * getattr(octree, f"{axis}_count")
62+
for axis in "uvw"
63+
]
64+
)
6765

68-
return np.array([xmin, xmax, ymin, ymax])
66+
return np.stack([origin, origin + span]).flatten(order="F")
6967

7068

7169
def mask_vertices_and_cells(

tests/utils_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
2+
# Copyright (c) 2026 Mira Geoscience Ltd. '
3+
# '
4+
# This file is part of simpeg-drivers package. '
5+
# '
6+
# simpeg-drivers is distributed under the terms and conditions of the MIT License '
7+
# (see LICENSE file at the root of this source code package). '
8+
# '
9+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
10+
11+
import numpy as np
12+
from geoh5py import Workspace
13+
from geoh5py.objects import Octree, Points
14+
from grid_apps.octree_creation.driver import OctreeDriver
15+
from grid_apps.octree_creation.options import OctreeOptions, RefinementOptions
16+
17+
from simpeg_drivers.utils.utils import mask_vertices_and_cells, octree_extents
18+
19+
20+
def test_octree_extents(tmp_path):
21+
with Workspace(tmp_path / "test.geoh5") as ws:
22+
X, Y = np.meshgrid(np.linspace(0, 1000, 51), np.linspace(0, 1000, 51))
23+
Z = np.zeros_like(X)
24+
vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
25+
pts = Points.create(ws, name="points", vertices=vertices)
26+
options = OctreeOptions(
27+
geoh5=ws,
28+
objects=pts,
29+
refinements=[
30+
RefinementOptions(
31+
refinement_object=pts, levels=[4, 2], horizon=False, distance=100
32+
),
33+
],
34+
)
35+
octree = OctreeDriver.octree_from_params(options)
36+
37+
extents = octree_extents(octree)
38+
assert np.allclose(extents, [-1112.5, 2087.5, -1112.5, 2087.5])
39+
40+
41+
def test_mask_vertices_and_cells():
42+
X, Y = np.meshgrid(np.arange(3), np.arange(3))
43+
Z = np.zeros_like(X)
44+
vertices = np.column_stack([X.ravel(), Y.ravel(), Z.ravel()])
45+
cells = np.array(
46+
[
47+
[0, 1, 3],
48+
[3, 1, 4],
49+
[1, 2, 4],
50+
[4, 2, 5],
51+
[3, 4, 6],
52+
[6, 4, 7],
53+
[4, 5, 7],
54+
[7, 5, 8],
55+
]
56+
)
57+
extent = [0.5, 2, 0, 2, 0, 1]
58+
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
59+
assert len(masked_vertices) == len(vertices)
60+
assert len(masked_cells) == len(cells)
61+
extent = [1.5, 2, 0, 2, 0, 1]
62+
masked_vertices, masked_cells = mask_vertices_and_cells(extent, vertices, cells)
63+
assert len(masked_vertices) == 6
64+
assert len(masked_cells) == 4

0 commit comments

Comments
 (0)