@@ -39,52 +39,45 @@ def cell_neighbors_along_axis(mesh: TreeMesh, axis: str) -> np.ndarray:
3939 return np .sort (stencil_indices , axis = 1 )
4040
4141
42- def cell_neighbors (mesh : TreeMesh ) -> np .ndarray :
43- """Find all cell neighbors in a TreeMesh."""
44-
45- x_neighbors = cell_neighbors_along_axis (mesh , "x" )
46- x_neighbors_backward = np .fliplr (x_neighbors )
47- y_neighbors = cell_neighbors_along_axis (mesh , "y" )
48- y_neighbors_backward = np .fliplr (y_neighbors )
49- max_index = np .max ([x_neighbors .max (), y_neighbors .max ()])
50- if mesh .dim == 3 :
51- z_neighbors = cell_neighbors_along_axis (mesh , "z" )
52- z_neighbors_backward = np .fliplr (z_neighbors )
53- max_index = np .max ([max_index , z_neighbors .max ()])
42+ def collect_all_neighbors (
43+ neighbors : list [np .ndarray ],
44+ neighbors_backwards : list [np .ndarray ],
45+ adjacent : list [np .ndarray ],
46+ adjacent_backwards : list [np .ndarray ],
47+ ) -> np .ndarray :
48+ """
49+ Collect all neighbors for cells in the mesh.
5450
51+ :param neighbors: Direct neighbors in each principle axes.
52+ :param neighbors_backwards: Direct neighbors in reverse order.
53+ :param adjacent: Adjacent neighbors (corners).
54+ :param adjacent_backwards: Adjacent neighbors in reverse order.
55+ """
5556 all_neighbors = [] # Store
56- x_adjacent = np .ones (max_index + 1 , dtype = "int" ) * - 1
57- y_adjacent = np .ones (max_index + 1 , dtype = "int" ) * - 1
58- x_adjacent_backward = np .ones (max_index + 1 , dtype = "int" ) * - 1
59- y_adjacent_backward = np .ones (max_index + 1 , dtype = "int" ) * - 1
60-
61- x_adjacent [y_neighbors [:, 0 ]] = y_neighbors [:, 1 ]
62- y_adjacent [x_neighbors [:, 1 ]] = x_neighbors [:, 0 ]
6357
64- x_adjacent_backward [ y_neighbors_backward [:, 0 ]] = y_neighbors_backward [:, 1 ]
65- y_adjacent_backward [ x_neighbors_backward [:, 1 ]] = x_neighbors_backward [:, 0 ]
58+ all_neighbors += [ neighbors [ 0 ] ]
59+ all_neighbors += [ neighbors [ 1 ] ]
6660
67- all_neighbors += [x_neighbors ]
68- all_neighbors += [y_neighbors ]
61+ all_neighbors += [np . c_ [ neighbors [ 0 ][:, 0 ], adjacent [ 0 ][ neighbors [ 0 ][:, 1 ]]] ]
62+ all_neighbors += [np . c_ [ neighbors [ 0 ][:, 1 ], adjacent [ 0 ][ neighbors [ 0 ][:, 0 ]]] ]
6963
70- all_neighbors += [np .c_ [x_neighbors [:, 0 ], x_adjacent [x_neighbors [:, 1 ]]]]
71- all_neighbors += [np .c_ [x_neighbors [:, 1 ], x_adjacent [x_neighbors [:, 0 ]]]]
72-
73- all_neighbors += [np .c_ [y_adjacent [y_neighbors [:, 0 ]], y_neighbors [:, 1 ]]]
74- all_neighbors += [np .c_ [y_adjacent [y_neighbors [:, 1 ]], y_neighbors [:, 0 ]]]
64+ all_neighbors += [np .c_ [adjacent [1 ][neighbors [1 ][:, 0 ]], neighbors [1 ][:, 1 ]]]
65+ all_neighbors += [np .c_ [adjacent [1 ][neighbors [1 ][:, 1 ]], neighbors [1 ][:, 0 ]]]
7566
7667 # Repeat backward for Treemesh
77- all_neighbors += [x_neighbors_backward ]
78- all_neighbors += [y_neighbors_backward ]
68+ all_neighbors += [neighbors_backwards [ 0 ] ]
69+ all_neighbors += [neighbors_backwards [ 1 ] ]
7970
8071 all_neighbors += [
8172 np .c_ [
82- x_neighbors_backward [:, 0 ], x_adjacent_backward [x_neighbors_backward [:, 1 ]]
73+ neighbors_backwards [0 ][:, 0 ],
74+ adjacent_backwards [0 ][neighbors_backwards [0 ][:, 1 ]],
8375 ]
8476 ]
8577 all_neighbors += [
8678 np .c_ [
87- x_neighbors_backward [:, 1 ], x_adjacent_backward [x_neighbors_backward [:, 0 ]]
79+ neighbors_backwards [0 ][:, 1 ],
80+ adjacent_backwards [0 ][neighbors_backwards [0 ][:, 0 ]],
8881 ]
8982 ]
9083
@@ -98,25 +91,24 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
9891 ]
9992
10093 # Use all the neighbours on the xy plane to find neighbours in z
101- if mesh . dim == 3 :
94+ if len ( neighbors ) == 3 :
10295 all_neighbors_z = []
103- z_adjacent = np .ones (max_index + 1 , dtype = "int" ) * - 1
104- z_adjacent_backward = np .ones (max_index + 1 , dtype = "int" ) * - 1
10596
106- z_adjacent [ z_neighbors [:, 0 ]] = z_neighbors [:, 1 ]
107- z_adjacent_backward [ z_neighbors_backward [:, 0 ]] = z_neighbors_backward [:, 1 ]
97+ all_neighbors_z += [ neighbors [ 2 ] ]
98+ all_neighbors_z += [ neighbors_backwards [ 2 ] ]
10899
109- all_neighbors_z += [z_neighbors ]
110- all_neighbors_z += [z_neighbors_backward ]
111-
112- all_neighbors_z += [np .c_ [all_neighbors [:, 0 ], z_adjacent [all_neighbors [:, 1 ]]]]
113- all_neighbors_z += [np .c_ [all_neighbors [:, 1 ], z_adjacent [all_neighbors [:, 0 ]]]]
100+ all_neighbors_z += [
101+ np .c_ [all_neighbors [:, 0 ], adjacent [2 ][all_neighbors [:, 1 ]]]
102+ ]
103+ all_neighbors_z += [
104+ np .c_ [all_neighbors [:, 1 ], adjacent [2 ][all_neighbors [:, 0 ]]]
105+ ]
114106
115107 all_neighbors_z += [
116- np .c_ [all_neighbors [:, 0 ], z_adjacent_backward [all_neighbors [:, 1 ]]]
108+ np .c_ [all_neighbors [:, 0 ], adjacent_backwards [ 2 ] [all_neighbors [:, 1 ]]]
117109 ]
118110 all_neighbors_z += [
119- np .c_ [all_neighbors [:, 1 ], z_adjacent_backward [all_neighbors [:, 0 ]]]
111+ np .c_ [all_neighbors [:, 1 ], adjacent_backwards [ 2 ] [all_neighbors [:, 0 ]]]
120112 ]
121113
122114 # Stack all and keep only unique pairs
@@ -131,6 +123,39 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
131123 return all_neighbors
132124
133125
126+ def cell_adjacent (neighbors : list [np .ndarray ]) -> list [np .ndarray ]:
127+ """Find all adjacent cells (corners) from cell neighbor array."""
128+
129+ dim = len (neighbors )
130+ max_index = np .max (neighbors )
131+ corners = - 1 * np .ones ((dim , max_index + 1 ), dtype = "int" )
132+
133+ corners [0 , neighbors [1 ][:, 0 ]] = neighbors [1 ][:, 1 ]
134+ corners [1 , neighbors [0 ][:, 1 ]] = neighbors [0 ][:, 0 ]
135+ if dim == 3 :
136+ corners [2 , neighbors [2 ][:, 0 ]] = neighbors [2 ][:, 1 ]
137+
138+ return [np .array (k ) for k in corners .tolist ()]
139+
140+
141+ def cell_neighbors (mesh : TreeMesh ) -> np .ndarray :
142+ """Find all cell neighbors in a TreeMesh."""
143+
144+ neighbors = []
145+ neighbors .append (cell_neighbors_along_axis (mesh , "x" ))
146+ neighbors .append (cell_neighbors_along_axis (mesh , "y" ))
147+ if mesh .dim == 3 :
148+ neighbors .append (cell_neighbors_along_axis (mesh , "z" ))
149+
150+ neighbors_backwards = [np .fliplr (k ) for k in neighbors ]
151+ corners = cell_adjacent (neighbors )
152+ corners_backwards = cell_adjacent (neighbors_backwards )
153+
154+ return collect_all_neighbors (
155+ neighbors , neighbors_backwards , corners , corners_backwards
156+ )
157+
158+
134159def rotate_xz_2d (mesh : TreeMesh , phi : np .ndarray ) -> ssp .csr_matrix :
135160 """
136161 Create a 2d ellipsoidal rotation matrix for the xz plane.
0 commit comments