Skip to content

Commit d1746f2

Browse files
committed
Merge branch 'GEOPY-2075' into GEOPY-2075_c
2 parents 27934f8 + 6271153 commit d1746f2

2 files changed

Lines changed: 161 additions & 44 deletions

File tree

simpeg_drivers/utils/regularization.py

Lines changed: 69 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -39,52 +39,45 @@ def cell_neighbors_along_axis(mesh: TreeMesh, axis: str) -> np.ndarray:
3939
return np.sort(stencil_indices, axis=1)
4040

4141

42-
def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
43-
"""Find all cell neighbors in a TreeMesh."""
44-
45-
x_neighbors = cell_neighbors_along_axis(mesh, "x")
46-
x_neighbors_backward = np.fliplr(x_neighbors)
47-
y_neighbors = cell_neighbors_along_axis(mesh, "y")
48-
y_neighbors_backward = np.fliplr(y_neighbors)
49-
max_index = np.max([x_neighbors.max(), y_neighbors.max()])
50-
if mesh.dim == 3:
51-
z_neighbors = cell_neighbors_along_axis(mesh, "z")
52-
z_neighbors_backward = np.fliplr(z_neighbors)
53-
max_index = np.max([max_index, z_neighbors.max()])
42+
def collect_all_neighbors(
43+
neighbors: list[np.ndarray],
44+
neighbors_backwards: list[np.ndarray],
45+
adjacent: list[np.ndarray],
46+
adjacent_backwards: list[np.ndarray],
47+
) -> np.ndarray:
48+
"""
49+
Collect all neighbors for cells in the mesh.
5450
51+
:param neighbors: Direct neighbors in each principle axes.
52+
:param neighbors_backwards: Direct neighbors in reverse order.
53+
:param adjacent: Adjacent neighbors (corners).
54+
:param adjacent_backwards: Adjacent neighbors in reverse order.
55+
"""
5556
all_neighbors = [] # Store
56-
x_adjacent = np.ones(max_index + 1, dtype="int") * -1
57-
y_adjacent = np.ones(max_index + 1, dtype="int") * -1
58-
x_adjacent_backward = np.ones(max_index + 1, dtype="int") * -1
59-
y_adjacent_backward = np.ones(max_index + 1, dtype="int") * -1
60-
61-
x_adjacent[y_neighbors[:, 0]] = y_neighbors[:, 1]
62-
y_adjacent[x_neighbors[:, 1]] = x_neighbors[:, 0]
6357

64-
x_adjacent_backward[y_neighbors_backward[:, 0]] = y_neighbors_backward[:, 1]
65-
y_adjacent_backward[x_neighbors_backward[:, 1]] = x_neighbors_backward[:, 0]
58+
all_neighbors += [neighbors[0]]
59+
all_neighbors += [neighbors[1]]
6660

67-
all_neighbors += [x_neighbors]
68-
all_neighbors += [y_neighbors]
61+
all_neighbors += [np.c_[neighbors[0][:, 0], adjacent[0][neighbors[0][:, 1]]]]
62+
all_neighbors += [np.c_[neighbors[0][:, 1], adjacent[0][neighbors[0][:, 0]]]]
6963

70-
all_neighbors += [np.c_[x_neighbors[:, 0], x_adjacent[x_neighbors[:, 1]]]]
71-
all_neighbors += [np.c_[x_neighbors[:, 1], x_adjacent[x_neighbors[:, 0]]]]
72-
73-
all_neighbors += [np.c_[y_adjacent[y_neighbors[:, 0]], y_neighbors[:, 1]]]
74-
all_neighbors += [np.c_[y_adjacent[y_neighbors[:, 1]], y_neighbors[:, 0]]]
64+
all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 0]], neighbors[1][:, 1]]]
65+
all_neighbors += [np.c_[adjacent[1][neighbors[1][:, 1]], neighbors[1][:, 0]]]
7566

7667
# Repeat backward for Treemesh
77-
all_neighbors += [x_neighbors_backward]
78-
all_neighbors += [y_neighbors_backward]
68+
all_neighbors += [neighbors_backwards[0]]
69+
all_neighbors += [neighbors_backwards[1]]
7970

8071
all_neighbors += [
8172
np.c_[
82-
x_neighbors_backward[:, 0], x_adjacent_backward[x_neighbors_backward[:, 1]]
73+
neighbors_backwards[0][:, 0],
74+
adjacent_backwards[0][neighbors_backwards[0][:, 1]],
8375
]
8476
]
8577
all_neighbors += [
8678
np.c_[
87-
x_neighbors_backward[:, 1], x_adjacent_backward[x_neighbors_backward[:, 0]]
79+
neighbors_backwards[0][:, 1],
80+
adjacent_backwards[0][neighbors_backwards[0][:, 0]],
8881
]
8982
]
9083

@@ -98,25 +91,24 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
9891
]
9992

10093
# Use all the neighbours on the xy plane to find neighbours in z
101-
if mesh.dim == 3:
94+
if len(neighbors) == 3:
10295
all_neighbors_z = []
103-
z_adjacent = np.ones(max_index + 1, dtype="int") * -1
104-
z_adjacent_backward = np.ones(max_index + 1, dtype="int") * -1
10596

106-
z_adjacent[z_neighbors[:, 0]] = z_neighbors[:, 1]
107-
z_adjacent_backward[z_neighbors_backward[:, 0]] = z_neighbors_backward[:, 1]
97+
all_neighbors_z += [neighbors[2]]
98+
all_neighbors_z += [neighbors_backwards[2]]
10899

109-
all_neighbors_z += [z_neighbors]
110-
all_neighbors_z += [z_neighbors_backward]
111-
112-
all_neighbors_z += [np.c_[all_neighbors[:, 0], z_adjacent[all_neighbors[:, 1]]]]
113-
all_neighbors_z += [np.c_[all_neighbors[:, 1], z_adjacent[all_neighbors[:, 0]]]]
100+
all_neighbors_z += [
101+
np.c_[all_neighbors[:, 0], adjacent[2][all_neighbors[:, 1]]]
102+
]
103+
all_neighbors_z += [
104+
np.c_[all_neighbors[:, 1], adjacent[2][all_neighbors[:, 0]]]
105+
]
114106

115107
all_neighbors_z += [
116-
np.c_[all_neighbors[:, 0], z_adjacent_backward[all_neighbors[:, 1]]]
108+
np.c_[all_neighbors[:, 0], adjacent_backwards[2][all_neighbors[:, 1]]]
117109
]
118110
all_neighbors_z += [
119-
np.c_[all_neighbors[:, 1], z_adjacent_backward[all_neighbors[:, 0]]]
111+
np.c_[all_neighbors[:, 1], adjacent_backwards[2][all_neighbors[:, 0]]]
120112
]
121113

122114
# Stack all and keep only unique pairs
@@ -131,6 +123,39 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
131123
return all_neighbors
132124

133125

126+
def cell_adjacent(neighbors: list[np.ndarray]) -> list[np.ndarray]:
127+
"""Find all adjacent cells (corners) from cell neighbor array."""
128+
129+
dim = len(neighbors)
130+
max_index = np.max(neighbors)
131+
corners = -1 * np.ones((dim, max_index + 1), dtype="int")
132+
133+
corners[0, neighbors[1][:, 0]] = neighbors[1][:, 1]
134+
corners[1, neighbors[0][:, 1]] = neighbors[0][:, 0]
135+
if dim == 3:
136+
corners[2, neighbors[2][:, 0]] = neighbors[2][:, 1]
137+
138+
return [np.array(k) for k in corners.tolist()]
139+
140+
141+
def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
142+
"""Find all cell neighbors in a TreeMesh."""
143+
144+
neighbors = []
145+
neighbors.append(cell_neighbors_along_axis(mesh, "x"))
146+
neighbors.append(cell_neighbors_along_axis(mesh, "y"))
147+
if mesh.dim == 3:
148+
neighbors.append(cell_neighbors_along_axis(mesh, "z"))
149+
150+
neighbors_backwards = [np.fliplr(k) for k in neighbors]
151+
corners = cell_adjacent(neighbors)
152+
corners_backwards = cell_adjacent(neighbors_backwards)
153+
154+
return collect_all_neighbors(
155+
neighbors, neighbors_backwards, corners, corners_backwards
156+
)
157+
158+
134159
def rotate_xz_2d(mesh: TreeMesh, phi: np.ndarray) -> ssp.csr_matrix:
135160
"""
136161
Create a 2d ellipsoidal rotation matrix for the xz plane.

tests/utils_regularization_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
2+
# Copyright (c) 2025 Mira Geoscience Ltd. '
3+
# '
4+
# This file is part of simpeg-drivers package. '
5+
# '
6+
# simpeg-drivers is distributed under the terms and conditions of the MIT License '
7+
# (see LICENSE file at the root of this source code package). '
8+
# '
9+
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
10+
11+
import numpy as np
12+
from discretize import TreeMesh
13+
14+
from simpeg_drivers.utils.regularization import (
15+
cell_adjacent,
16+
cell_neighbors_along_axis,
17+
collect_all_neighbors,
18+
)
19+
20+
21+
def get_mesh():
22+
mesh = TreeMesh(h=[[10.0] * 4, [10.0] * 4, [10.0] * 4], diagonal_balance=False)
23+
mesh.refine(2)
24+
return mesh
25+
26+
27+
def test_cell_neighbors_along_axis():
28+
mesh = get_mesh()
29+
centers = mesh.cell_centers
30+
neighbors = cell_neighbors_along_axis(mesh, "x")
31+
assert np.allclose(centers[7], [15.0, 15.0, 15.0])
32+
assert np.allclose(
33+
centers[neighbors[neighbors[:, 0] == 7][0][1]], [25.0, 15.0, 15.0]
34+
)
35+
assert np.allclose(
36+
centers[neighbors[neighbors[:, 1] == 7][0][0]], [5.0, 15.0, 15.0]
37+
)
38+
neighbors = cell_neighbors_along_axis(mesh, "y")
39+
assert np.allclose(
40+
centers[neighbors[neighbors[:, 0] == 7][0][1]], [15.0, 25.0, 15.0]
41+
)
42+
assert np.allclose(
43+
centers[neighbors[neighbors[:, 1] == 7][0][0]], [15.0, 5.0, 15.0]
44+
)
45+
neighbors = cell_neighbors_along_axis(mesh, "z")
46+
assert np.allclose(
47+
centers[neighbors[neighbors[:, 0] == 7][0][1]], [15.0, 15.0, 25.0]
48+
)
49+
assert np.allclose(
50+
centers[neighbors[neighbors[:, 1] == 7][0][0]], [15.0, 15.0, 5.0]
51+
)
52+
53+
54+
def test_collect_all_neighbors():
55+
mesh = get_mesh()
56+
centers = mesh.cell_centers
57+
neighbors = [cell_neighbors_along_axis(mesh, k) for k in "xyz"]
58+
neighbors_bck = [np.fliplr(k) for k in neighbors]
59+
corners = cell_adjacent(neighbors)
60+
corners_bck = cell_adjacent(neighbors_bck)
61+
all_neighbors = collect_all_neighbors(
62+
neighbors, neighbors_bck, corners, corners_bck
63+
)
64+
assert np.allclose(centers[7], [15.0, 15.0, 15.0])
65+
neighbor_centers = centers[all_neighbors[all_neighbors[:, 0] == 7][:, 1]].tolist()
66+
assert [5, 5, 5] in neighbor_centers
67+
assert [15, 5, 5] in neighbor_centers
68+
assert [25, 5, 5] in neighbor_centers
69+
assert [5, 15, 5] in neighbor_centers
70+
assert [15, 15, 5] in neighbor_centers
71+
assert [25, 15, 5] in neighbor_centers
72+
assert [5, 25, 5] in neighbor_centers
73+
assert [15, 25, 5] in neighbor_centers
74+
assert [25, 25, 5] in neighbor_centers
75+
assert [5, 5, 15] in neighbor_centers
76+
assert [15, 5, 15] in neighbor_centers
77+
assert [25, 5, 15] in neighbor_centers
78+
assert [5, 15, 15] in neighbor_centers
79+
assert [25, 15, 15] in neighbor_centers
80+
assert [5, 25, 15] in neighbor_centers
81+
assert [15, 25, 15] in neighbor_centers
82+
assert [25, 25, 15] in neighbor_centers
83+
assert [5, 5, 25] in neighbor_centers
84+
assert [15, 5, 25] in neighbor_centers
85+
assert [25, 5, 25] in neighbor_centers
86+
assert [5, 15, 25] in neighbor_centers
87+
assert [15, 15, 25] in neighbor_centers
88+
assert [25, 15, 25] in neighbor_centers
89+
assert [5, 25, 25] in neighbor_centers
90+
assert [15, 25, 25] in neighbor_centers
91+
assert [25, 25, 25] in neighbor_centers
92+
assert [15, 15, 15] not in neighbor_centers

0 commit comments

Comments
 (0)