Skip to content

Commit 0dfa19b

Browse files
committed
refactor neighbor/corner finding and add unit test
1 parent b47a375 commit 0dfa19b

2 files changed

Lines changed: 145 additions & 45 deletions

File tree

simpeg_drivers/utils/regularization.py

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -38,52 +38,32 @@ def cell_neighbors_along_axis(mesh: TreeMesh, axis: str) -> np.ndarray:
3838
return np.sort(stencil_indices, axis=1)
3939

4040

41-
def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
42-
"""Find all cell neighbors in a TreeMesh."""
43-
44-
x_neighbors = cell_neighbors_along_axis(mesh, "x")
45-
x_neighbors_backward = np.fliplr(x_neighbors)
46-
y_neighbors = cell_neighbors_along_axis(mesh, "y")
47-
y_neighbors_backward = np.fliplr(y_neighbors)
48-
max_index = np.max([x_neighbors.max(), y_neighbors.max()])
49-
if mesh.dim == 3:
50-
z_neighbors = cell_neighbors_along_axis(mesh, "z")
51-
z_neighbors_backward = np.fliplr(z_neighbors)
52-
max_index = np.max([max_index, z_neighbors.max()])
53-
41+
def collect_all_neighbors(neighbors, neighbors_backwards, corners, corners_backwards):
5442
all_neighbors = [] # Store
55-
x_adjacent = np.ones(max_index + 1, dtype="int") * -1
56-
y_adjacent = np.ones(max_index + 1, dtype="int") * -1
57-
x_adjacent_backward = np.ones(max_index + 1, dtype="int") * -1
58-
y_adjacent_backward = np.ones(max_index + 1, dtype="int") * -1
59-
60-
x_adjacent[y_neighbors[:, 0]] = y_neighbors[:, 1]
61-
y_adjacent[x_neighbors[:, 1]] = x_neighbors[:, 0]
6243

63-
x_adjacent_backward[y_neighbors_backward[:, 0]] = y_neighbors_backward[:, 1]
64-
y_adjacent_backward[x_neighbors_backward[:, 1]] = x_neighbors_backward[:, 0]
44+
all_neighbors += [neighbors[0]]
45+
all_neighbors += [neighbors[1]]
6546

66-
all_neighbors += [x_neighbors]
67-
all_neighbors += [y_neighbors]
47+
all_neighbors += [np.c_[neighbors[0][:, 0], corners[0][neighbors[0][:, 1]]]]
48+
all_neighbors += [np.c_[neighbors[0][:, 1], corners[0][neighbors[0][:, 0]]]]
6849

69-
all_neighbors += [np.c_[x_neighbors[:, 0], x_adjacent[x_neighbors[:, 1]]]]
70-
all_neighbors += [np.c_[x_neighbors[:, 1], x_adjacent[x_neighbors[:, 0]]]]
71-
72-
all_neighbors += [np.c_[y_adjacent[y_neighbors[:, 0]], y_neighbors[:, 1]]]
73-
all_neighbors += [np.c_[y_adjacent[y_neighbors[:, 1]], y_neighbors[:, 0]]]
50+
all_neighbors += [np.c_[corners[1][neighbors[1][:, 0]], neighbors[1][:, 1]]]
51+
all_neighbors += [np.c_[corners[1][neighbors[1][:, 1]], neighbors[1][:, 0]]]
7452

7553
# Repeat backward for Treemesh
76-
all_neighbors += [x_neighbors_backward]
77-
all_neighbors += [y_neighbors_backward]
54+
all_neighbors += [neighbors_backwards[0]]
55+
all_neighbors += [neighbors_backwards[1]]
7856

7957
all_neighbors += [
8058
np.c_[
81-
x_neighbors_backward[:, 0], x_adjacent_backward[x_neighbors_backward[:, 1]]
59+
neighbors_backwards[0][:, 0],
60+
corners_backwards[0][neighbors_backwards[0][:, 1]],
8261
]
8362
]
8463
all_neighbors += [
8564
np.c_[
86-
x_neighbors_backward[:, 1], x_adjacent_backward[x_neighbors_backward[:, 0]]
65+
neighbors_backwards[0][:, 1],
66+
corners_backwards[0][neighbors_backwards[0][:, 0]],
8767
]
8868
]
8969

@@ -97,25 +77,20 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
9777
]
9878

9979
# Use all the neighbours on the xy plane to find neighbours in z
100-
if mesh.dim == 3:
80+
if len(neighbors) == 3:
10181
all_neighbors_z = []
102-
z_adjacent = np.ones(max_index + 1, dtype="int") * -1
103-
z_adjacent_backward = np.ones(max_index + 1, dtype="int") * -1
104-
105-
z_adjacent[z_neighbors[:, 0]] = z_neighbors[:, 1]
106-
z_adjacent_backward[z_neighbors_backward[:, 0]] = z_neighbors_backward[:, 1]
10782

108-
all_neighbors_z += [z_neighbors]
109-
all_neighbors_z += [z_neighbors_backward]
83+
all_neighbors_z += [neighbors[2]]
84+
all_neighbors_z += [neighbors_backwards[2]]
11085

111-
all_neighbors_z += [np.c_[all_neighbors[:, 0], z_adjacent[all_neighbors[:, 1]]]]
112-
all_neighbors_z += [np.c_[all_neighbors[:, 1], z_adjacent[all_neighbors[:, 0]]]]
86+
all_neighbors_z += [np.c_[all_neighbors[:, 0], corners[2][all_neighbors[:, 1]]]]
87+
all_neighbors_z += [np.c_[all_neighbors[:, 1], corners[2][all_neighbors[:, 0]]]]
11388

11489
all_neighbors_z += [
115-
np.c_[all_neighbors[:, 0], z_adjacent_backward[all_neighbors[:, 1]]]
90+
np.c_[all_neighbors[:, 0], corners_backwards[2][all_neighbors[:, 1]]]
11691
]
11792
all_neighbors_z += [
118-
np.c_[all_neighbors[:, 1], z_adjacent_backward[all_neighbors[:, 0]]]
93+
np.c_[all_neighbors[:, 1], corners_backwards[2][all_neighbors[:, 0]]]
11994
]
12095

12196
# Stack all and keep only unique pairs
@@ -130,6 +105,39 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
130105
return all_neighbors
131106

132107

108+
def cell_adjacent(neighbors: list[np.ndarray]) -> list[np.ndarray]:
109+
"""Find all cell corners from cell neighbor array."""
110+
111+
dim = len(neighbors)
112+
max_index = np.max(neighbors)
113+
corners = -1 * np.ones((dim, max_index + 1), dtype="int")
114+
115+
corners[0, neighbors[1][:, 0]] = neighbors[1][:, 1]
116+
corners[1, neighbors[0][:, 1]] = neighbors[0][:, 0]
117+
if dim == 3:
118+
corners[2, neighbors[2][:, 0]] = neighbors[2][:, 1]
119+
120+
return [np.array(k) for k in corners.tolist()]
121+
122+
123+
def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
124+
"""Find all cell neighbors in a TreeMesh."""
125+
126+
neighbors = []
127+
neighbors.append(cell_neighbors_along_axis(mesh, "x"))
128+
neighbors.append(cell_neighbors_along_axis(mesh, "y"))
129+
if mesh.dim == 3:
130+
neighbors.append(cell_neighbors_along_axis(mesh, "z"))
131+
132+
neighbors_backwards = [np.fliplr(k) for k in neighbors]
133+
corners = cell_adjacent(neighbors)
134+
corners_backwards = cell_adjacent(neighbors_backwards)
135+
136+
return collect_all_neighbors(
137+
neighbors, neighbors_backwards, corners, corners_backwards
138+
)
139+
140+
133141
def rotate_xz_2d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix:
134142
"""
135143
Create a 2d ellipsoidal rotation matrix for the xz plane.

tests/utils_regularization_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
2+
# Copyright (c) 2025 Mira Geoscience Ltd. '
3+
# '
4+
# This file is part of simpeg-drivers package. '
5+
# '
6+
# simpeg-drivers is distributed under the terms and conditions of the MIT License '
7+
# (see LICENSE file at the root of this source code package). '
8+
# '
9+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
10+
11+
import numpy as np
12+
from discretize import TreeMesh
13+
14+
from simpeg_drivers.utils.regularization import (
15+
cell_adjacent,
16+
cell_neighbors_along_axis,
17+
collect_all_neighbors,
18+
)
19+
20+
21+
def get_mesh():
22+
mesh = TreeMesh(h=[[10.0] * 4, [10.0] * 4, [10.0] * 4], diagonal_balance=False)
23+
mesh.refine(2)
24+
return mesh
25+
26+
27+
def test_cell_neighbors_along_axis():
28+
mesh = get_mesh()
29+
centers = mesh.cell_centers
30+
neighbors = cell_neighbors_along_axis(mesh, "x")
31+
assert np.allclose(centers[7], [15.0, 15.0, 15.0])
32+
assert np.allclose(
33+
centers[neighbors[neighbors[:, 0] == 7][0][1]], [25.0, 15.0, 15.0]
34+
)
35+
assert np.allclose(
36+
centers[neighbors[neighbors[:, 1] == 7][0][0]], [5.0, 15.0, 15.0]
37+
)
38+
neighbors = cell_neighbors_along_axis(mesh, "y")
39+
assert np.allclose(
40+
centers[neighbors[neighbors[:, 0] == 7][0][1]], [15.0, 25.0, 15.0]
41+
)
42+
assert np.allclose(
43+
centers[neighbors[neighbors[:, 1] == 7][0][0]], [15.0, 5.0, 15.0]
44+
)
45+
neighbors = cell_neighbors_along_axis(mesh, "z")
46+
assert np.allclose(
47+
centers[neighbors[neighbors[:, 0] == 7][0][1]], [15.0, 15.0, 25.0]
48+
)
49+
assert np.allclose(
50+
centers[neighbors[neighbors[:, 1] == 7][0][0]], [15.0, 15.0, 5.0]
51+
)
52+
53+
54+
def test_collect_all_neighbors():
55+
mesh = get_mesh()
56+
centers = mesh.cell_centers
57+
neighbors = [cell_neighbors_along_axis(mesh, k) for k in "xyz"]
58+
neighbors_bck = [np.fliplr(k) for k in neighbors]
59+
corners = cell_adjacent(neighbors)
60+
corners_bck = cell_adjacent(neighbors_bck)
61+
all_neighbors = collect_all_neighbors(
62+
neighbors, neighbors_bck, corners, corners_bck
63+
)
64+
assert np.allclose(centers[7], [15.0, 15.0, 15.0])
65+
neighbor_centers = centers[all_neighbors[all_neighbors[:, 0] == 7][:, 1]].tolist()
66+
assert [5, 5, 5] in neighbor_centers
67+
assert [15, 5, 5] in neighbor_centers
68+
assert [25, 5, 5] in neighbor_centers
69+
assert [5, 15, 5] in neighbor_centers
70+
assert [15, 15, 5] in neighbor_centers
71+
assert [25, 15, 5] in neighbor_centers
72+
assert [5, 25, 5] in neighbor_centers
73+
assert [15, 25, 5] in neighbor_centers
74+
assert [25, 25, 5] in neighbor_centers
75+
assert [5, 5, 15] in neighbor_centers
76+
assert [15, 5, 15] in neighbor_centers
77+
assert [25, 5, 15] in neighbor_centers
78+
assert [5, 15, 15] in neighbor_centers
79+
assert [25, 15, 15] in neighbor_centers
80+
assert [5, 25, 15] in neighbor_centers
81+
assert [15, 25, 15] in neighbor_centers
82+
assert [25, 25, 15] in neighbor_centers
83+
assert [5, 5, 25] in neighbor_centers
84+
assert [15, 5, 25] in neighbor_centers
85+
assert [25, 5, 25] in neighbor_centers
86+
assert [5, 15, 25] in neighbor_centers
87+
assert [15, 15, 25] in neighbor_centers
88+
assert [25, 15, 25] in neighbor_centers
89+
assert [5, 25, 25] in neighbor_centers
90+
assert [15, 25, 25] in neighbor_centers
91+
assert [25, 25, 25] in neighbor_centers
92+
assert [15, 15, 15] not in neighbor_centers

0 commit comments

Comments
 (0)