@@ -38,52 +38,32 @@ def cell_neighbors_along_axis(mesh: TreeMesh, axis: str) -> np.ndarray:
3838 return np .sort (stencil_indices , axis = 1 )
3939
4040
41- def cell_neighbors (mesh : TreeMesh ) -> np .ndarray :
42- """Find all cell neighbors in a TreeMesh."""
43-
44- x_neighbors = cell_neighbors_along_axis (mesh , "x" )
45- x_neighbors_backward = np .fliplr (x_neighbors )
46- y_neighbors = cell_neighbors_along_axis (mesh , "y" )
47- y_neighbors_backward = np .fliplr (y_neighbors )
48- max_index = np .max ([x_neighbors .max (), y_neighbors .max ()])
49- if mesh .dim == 3 :
50- z_neighbors = cell_neighbors_along_axis (mesh , "z" )
51- z_neighbors_backward = np .fliplr (z_neighbors )
52- max_index = np .max ([max_index , z_neighbors .max ()])
53-
41+ def collect_all_neighbors (neighbors , neighbors_backwards , corners , corners_backwards ):
5442 all_neighbors = [] # Store
55- x_adjacent = np .ones (max_index + 1 , dtype = "int" ) * - 1
56- y_adjacent = np .ones (max_index + 1 , dtype = "int" ) * - 1
57- x_adjacent_backward = np .ones (max_index + 1 , dtype = "int" ) * - 1
58- y_adjacent_backward = np .ones (max_index + 1 , dtype = "int" ) * - 1
59-
60- x_adjacent [y_neighbors [:, 0 ]] = y_neighbors [:, 1 ]
61- y_adjacent [x_neighbors [:, 1 ]] = x_neighbors [:, 0 ]
6243
63- x_adjacent_backward [ y_neighbors_backward [:, 0 ]] = y_neighbors_backward [:, 1 ]
64- y_adjacent_backward [ x_neighbors_backward [:, 1 ]] = x_neighbors_backward [:, 0 ]
44+ all_neighbors += [ neighbors [ 0 ] ]
45+ all_neighbors += [ neighbors [ 1 ] ]
6546
66- all_neighbors += [x_neighbors ]
67- all_neighbors += [y_neighbors ]
47+ all_neighbors += [np . c_ [ neighbors [ 0 ][:, 0 ], corners [ 0 ][ neighbors [ 0 ][:, 1 ]]] ]
48+ all_neighbors += [np . c_ [ neighbors [ 0 ][:, 1 ], corners [ 0 ][ neighbors [ 0 ][:, 0 ]]] ]
6849
69- all_neighbors += [np .c_ [x_neighbors [:, 0 ], x_adjacent [x_neighbors [:, 1 ]]]]
70- all_neighbors += [np .c_ [x_neighbors [:, 1 ], x_adjacent [x_neighbors [:, 0 ]]]]
71-
72- all_neighbors += [np .c_ [y_adjacent [y_neighbors [:, 0 ]], y_neighbors [:, 1 ]]]
73- all_neighbors += [np .c_ [y_adjacent [y_neighbors [:, 1 ]], y_neighbors [:, 0 ]]]
50+ all_neighbors += [np .c_ [corners [1 ][neighbors [1 ][:, 0 ]], neighbors [1 ][:, 1 ]]]
51+ all_neighbors += [np .c_ [corners [1 ][neighbors [1 ][:, 1 ]], neighbors [1 ][:, 0 ]]]
7452
7553 # Repeat backward for Treemesh
76- all_neighbors += [x_neighbors_backward ]
77- all_neighbors += [y_neighbors_backward ]
54+ all_neighbors += [neighbors_backwards [ 0 ] ]
55+ all_neighbors += [neighbors_backwards [ 1 ] ]
7856
7957 all_neighbors += [
8058 np .c_ [
81- x_neighbors_backward [:, 0 ], x_adjacent_backward [x_neighbors_backward [:, 1 ]]
59+ neighbors_backwards [0 ][:, 0 ],
60+ corners_backwards [0 ][neighbors_backwards [0 ][:, 1 ]],
8261 ]
8362 ]
8463 all_neighbors += [
8564 np .c_ [
86- x_neighbors_backward [:, 1 ], x_adjacent_backward [x_neighbors_backward [:, 0 ]]
65+ neighbors_backwards [0 ][:, 1 ],
66+ corners_backwards [0 ][neighbors_backwards [0 ][:, 0 ]],
8767 ]
8868 ]
8969
@@ -97,25 +77,20 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
9777 ]
9878
9979 # Use all the neighbours on the xy plane to find neighbours in z
100- if mesh . dim == 3 :
80+ if len ( neighbors ) == 3 :
10181 all_neighbors_z = []
102- z_adjacent = np .ones (max_index + 1 , dtype = "int" ) * - 1
103- z_adjacent_backward = np .ones (max_index + 1 , dtype = "int" ) * - 1
104-
105- z_adjacent [z_neighbors [:, 0 ]] = z_neighbors [:, 1 ]
106- z_adjacent_backward [z_neighbors_backward [:, 0 ]] = z_neighbors_backward [:, 1 ]
10782
108- all_neighbors_z += [z_neighbors ]
109- all_neighbors_z += [z_neighbors_backward ]
83+ all_neighbors_z += [neighbors [ 2 ] ]
84+ all_neighbors_z += [neighbors_backwards [ 2 ] ]
11085
111- all_neighbors_z += [np .c_ [all_neighbors [:, 0 ], z_adjacent [all_neighbors [:, 1 ]]]]
112- all_neighbors_z += [np .c_ [all_neighbors [:, 1 ], z_adjacent [all_neighbors [:, 0 ]]]]
86+ all_neighbors_z += [np .c_ [all_neighbors [:, 0 ], corners [ 2 ] [all_neighbors [:, 1 ]]]]
87+ all_neighbors_z += [np .c_ [all_neighbors [:, 1 ], corners [ 2 ] [all_neighbors [:, 0 ]]]]
11388
11489 all_neighbors_z += [
115- np .c_ [all_neighbors [:, 0 ], z_adjacent_backward [all_neighbors [:, 1 ]]]
90+ np .c_ [all_neighbors [:, 0 ], corners_backwards [ 2 ] [all_neighbors [:, 1 ]]]
11691 ]
11792 all_neighbors_z += [
118- np .c_ [all_neighbors [:, 1 ], z_adjacent_backward [all_neighbors [:, 0 ]]]
93+ np .c_ [all_neighbors [:, 1 ], corners_backwards [ 2 ] [all_neighbors [:, 0 ]]]
11994 ]
12095
12196 # Stack all and keep only unique pairs
@@ -130,6 +105,39 @@ def cell_neighbors(mesh: TreeMesh) -> np.ndarray:
130105 return all_neighbors
131106
132107
108+ def cell_adjacent (neighbors : list [np .ndarray ]) -> list [np .ndarray ]:
109+ """Find all cell corners from cell neighbor array."""
110+
111+ dim = len (neighbors )
112+ max_index = np .max (neighbors )
113+ corners = - 1 * np .ones ((dim , max_index + 1 ), dtype = "int" )
114+
115+ corners [0 , neighbors [1 ][:, 0 ]] = neighbors [1 ][:, 1 ]
116+ corners [1 , neighbors [0 ][:, 1 ]] = neighbors [0 ][:, 0 ]
117+ if dim == 3 :
118+ corners [2 , neighbors [2 ][:, 0 ]] = neighbors [2 ][:, 1 ]
119+
120+ return [np .array (k ) for k in corners .tolist ()]
121+
122+
123+ def cell_neighbors (mesh : TreeMesh ) -> np .ndarray :
124+ """Find all cell neighbors in a TreeMesh."""
125+
126+ neighbors = []
127+ neighbors .append (cell_neighbors_along_axis (mesh , "x" ))
128+ neighbors .append (cell_neighbors_along_axis (mesh , "y" ))
129+ if mesh .dim == 3 :
130+ neighbors .append (cell_neighbors_along_axis (mesh , "z" ))
131+
132+ neighbors_backwards = [np .fliplr (k ) for k in neighbors ]
133+ corners = cell_adjacent (neighbors )
134+ corners_backwards = cell_adjacent (neighbors_backwards )
135+
136+ return collect_all_neighbors (
137+ neighbors , neighbors_backwards , corners , corners_backwards
138+ )
139+
140+
133141def rotate_xz_2d (mesh : TreeMesh , phi : np .ndarray ) -> ssp .csr_matrix :
134142 """
135143 Create a 2d ellipsoidal rotation matrix for the xz plane.
0 commit comments