Skip to content

Commit 80627bf

Browse files
majosminducer
authored andcommitted
combine/optimize _match_vertices and find_point_permutation implementations (#458, part 2)
* optimize vertex matching in periodicity processing * use np.median instead of approximating with np.histogram turns out to be just as fast * combine _match_vertices and mesh.tools.find_point_permutation into find_point_to_point_mapping * add tests for find_point_to_point_mapping * use find_point_to_point_mapping instead of find_point_permutation
1 parent 69bcd5a commit 80627bf

4 files changed

Lines changed: 291 additions & 128 deletions

File tree

meshmode/discretization/connection/direct.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,17 @@ def _resample_point_pick_indices(self, to_group_index: int, ibatch_index: int,
422422
ibatch = self.groups[to_group_index].batches[ibatch_index]
423423
from_grp = self.from_discr.groups[ibatch.from_group_index]
424424

425-
from meshmode.mesh.tools import find_point_permutation
426-
return find_point_permutation(
427-
targets=ibatch.result_unit_nodes,
428-
permutees=from_grp.unit_nodes,
425+
from meshmode.mesh.tools import find_point_to_point_mapping
426+
src_idx_to_tgt_idx = find_point_to_point_mapping(
427+
src_points=ibatch.result_unit_nodes,
428+
tgt_points=from_grp.unit_nodes,
429429
tol_multiplier=tol_multiplier)
430430

431+
return (
432+
src_idx_to_tgt_idx
433+
if np.all(src_idx_to_tgt_idx >= 0)
434+
else None)
435+
431436
@keyed_memoize_method(lambda actx, to_group_index, ibatch_index,
432437
tol_multiplier=None: (to_group_index, ibatch_index, tol_multiplier))
433438
def _frozen_resample_point_pick_indices(self, actx: ArrayContext,

meshmode/mesh/processing.py

Lines changed: 31 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from collections.abc import Callable, Mapping, Sequence
2626
from dataclasses import dataclass, replace
2727
from functools import reduce
28-
from typing import Any, Literal
28+
from typing import Literal
29+
from warnings import warn
2930

3031
import numpy as np
3132
import numpy.linalg as la
@@ -45,7 +46,7 @@
4546
_FaceIDs,
4647
make_mesh,
4748
)
48-
from meshmode.mesh.tools import AffineMap, find_point_permutation
49+
from meshmode.mesh.tools import AffineMap, find_point_to_point_mapping
4950

5051

5152
__doc__ = """
@@ -588,18 +589,18 @@ def evec(i: int) -> np.ndarray:
588589
result[i] = 1
589590
return result
590591

591-
def unpack_single(ary: np.ndarray | None) -> np.ndarray:
592+
def unpack_single(ary: np.ndarray | None) -> int:
592593
assert ary is not None
593594
item, = ary
594595
return item
595596

596-
base_vertex_index = unpack_single(find_point_permutation(
597-
targets=-np.ones(grp.dim),
598-
permutees=grp.vertex_unit_coordinates().T))
597+
base_vertex_index = unpack_single(find_point_to_point_mapping(
598+
src_points=-np.ones(grp.dim).reshape(-1, 1),
599+
tgt_points=grp.vertex_unit_coordinates().T))
599600
spanning_vertex_indices = [
600-
unpack_single(find_point_permutation(
601-
targets=-np.ones(grp.dim) + 2 * evec(i),
602-
permutees=grp.vertex_unit_coordinates().T))
601+
unpack_single(find_point_to_point_mapping(
602+
src_points=(-np.ones(grp.dim) + 2 * evec(i)).reshape(-1, 1),
603+
tgt_points=grp.vertex_unit_coordinates().T))
603604
for i in range(grp.dim)
604605
]
605606

@@ -747,11 +748,10 @@ def _get_tensor_product_element_flip_matrix_and_vertex_permutation(
747748
unit_flip_matrix,
748749
grp.vertex_unit_coordinates().T)
749750

750-
vertex_permutation_to = find_point_permutation(
751-
targets=flipped_vertices,
752-
permutees=grp.vertex_unit_coordinates().T,
753-
)
754-
if vertex_permutation_to is None:
751+
vertex_permutation_to = find_point_to_point_mapping(
752+
src_points=flipped_vertices,
753+
tgt_points=grp.vertex_unit_coordinates().T)
754+
if np.any(vertex_permutation_to < 0):
755755
raise RuntimeError("flip permutation was not found")
756756

757757
flipped_unit_nodes = np.einsum("ij,jn->in", unit_flip_matrix, grp.unit_nodes)
@@ -1020,85 +1020,6 @@ def split_mesh_groups(
10201020
# }}}
10211021

10221022

1023-
# {{{ vertex matching
1024-
1025-
def _match_vertices(
1026-
mesh: Mesh,
1027-
src_vertex_indices: np.ndarray,
1028-
tgt_vertex_indices: np.ndarray, *,
1029-
aff_map: AffineMap | None = None,
1030-
tol: float = 1e-12,
1031-
use_tree: bool | None = None) -> np.ndarray:
1032-
if mesh.vertices is None:
1033-
raise ValueError("Mesh must have vertices")
1034-
1035-
if aff_map is None:
1036-
aff_map = AffineMap()
1037-
1038-
if use_tree is None:
1039-
# Empirically, the tree version becomes faster at 2**13.
1040-
# The temporary (displacements) below at that size requires
1041-
# 1.6GB, which seems like a lot. Capping at 2**11 instead,
1042-
# which requires a more reasonable 100M.
1043-
use_tree = len(tgt_vertex_indices) >= 2**11
1044-
1045-
src_vertices = mesh.vertices[:, src_vertex_indices]
1046-
tgt_vertices = mesh.vertices[:, tgt_vertex_indices]
1047-
1048-
mapped_src_vertices = aff_map(src_vertices)
1049-
1050-
if use_tree:
1051-
tgt_vertex_bboxes = np.stack((
1052-
tgt_vertices - tol,
1053-
tgt_vertices + tol))
1054-
1055-
from pytools.spatial_btree import SpatialBinaryTreeBucket
1056-
tree = SpatialBinaryTreeBucket(
1057-
np.min(tgt_vertex_bboxes[0], axis=1),
1058-
np.max(tgt_vertex_bboxes[1], axis=1))
1059-
for ivertex in range(len(tgt_vertex_indices)):
1060-
tree.insert(ivertex, tgt_vertex_bboxes[:, :, ivertex])
1061-
1062-
matched_tgt_vertices: np.ndarray[tuple[int, ...], np.dtype[Any]] \
1063-
= np.full(len(src_vertex_indices), -1)
1064-
for ivertex in range(len(src_vertex_indices)):
1065-
mapped_src_vertex = mapped_src_vertices[:, ivertex]
1066-
matches = np.array(list(tree.generate_matches(mapped_src_vertex)))
1067-
match_bboxes = tgt_vertex_bboxes[:, :, matches]
1068-
in_bbox = np.all(
1069-
(mapped_src_vertex[:, np.newaxis] >= match_bboxes[0, :, :])
1070-
& (mapped_src_vertex[:, np.newaxis] <= match_bboxes[1, :, :]),
1071-
axis=0)
1072-
candidate_indices = matches[in_bbox]
1073-
if len(candidate_indices) == 0:
1074-
continue
1075-
displacements = (
1076-
mapped_src_vertex.reshape(-1, 1)
1077-
- tgt_vertices[:, candidate_indices])
1078-
distances_sq = np.sum(displacements**2, axis=0)
1079-
matched_tgt_vertices[ivertex] = (
1080-
tgt_vertex_indices[candidate_indices[np.argmin(distances_sq)]])
1081-
1082-
else:
1083-
displacements = (
1084-
mapped_src_vertices.reshape(mesh.dim, -1, 1)
1085-
- tgt_vertices.reshape(mesh.dim, 1, -1))
1086-
distances_sq = np.sum(displacements**2, axis=0)
1087-
1088-
vertex_indices, = np.indices((len(src_vertex_indices),))
1089-
min_distance_sq_indices = np.argmin(distances_sq, axis=1)
1090-
min_distances_sq = distances_sq[vertex_indices, min_distance_sq_indices]
1091-
1092-
matched_tgt_vertices = np.where(
1093-
min_distances_sq < tol**2,
1094-
tgt_vertex_indices[min_distance_sq_indices],
1095-
-1)
1096-
1097-
return matched_tgt_vertices
1098-
1099-
# }}}
1100-
1101-
11021023
# {{{ boundary face matching
11031024

11041025
@dataclass(frozen=True)
@@ -1170,8 +1091,8 @@ def _get_face_vertex_indices(mesh: Mesh, face_ids: _FaceIDs) -> np.ndarray:
11701091

11711092

11721093
def _match_boundary_faces(
1173-
mesh: Mesh, bdry_pair_mapping: BoundaryPairMapping, tol: float, *,
1174-
use_tree: bool | None = None) -> tuple[_FaceIDs, _FaceIDs]:
1094+
mesh: Mesh, bdry_pair_mapping: BoundaryPairMapping, tol: float,
1095+
) -> tuple[_FaceIDs, _FaceIDs]:
11751096
"""
11761097
Given a :class:`BoundaryPairMapping` *bdry_pair_mapping*, return the
11771098
correspondence between faces of the two boundaries (expressed as a pair of
@@ -1182,8 +1103,6 @@ def _match_boundary_faces(
11821103
whose faces are to be matched.
11831104
:arg tol: The allowed tolerance between the transformed vertex coordinates of
11841105
the first boundary and the vertex coordinates of the second boundary.
1185-
:arg use_tree: Optional argument indicating whether to use a spatial binary
1186-
search tree or a (quadratic) numpy algorithm when matching vertices.
11871106
:returns: A pair of :class:`meshmode.mesh._FaceIDs`, each having a number of
11881107
entries equal to the number of faces in the boundary, that represents the
11891108
correspondence between the two boundaries' faces. The first element in the
@@ -1217,12 +1136,15 @@ def _match_boundary_faces(
12171136
bdry_n_vertex_indices = np.unique(bdry_n_face_vertex_indices)
12181137
bdry_n_vertex_indices = bdry_n_vertex_indices[bdry_n_vertex_indices >= 0]
12191138

1220-
matched_bdry_n_vertex_indices = _match_vertices(
1221-
mesh, bdry_m_vertex_indices, bdry_n_vertex_indices,
1222-
aff_map=bdry_pair_mapping.aff_map, tol=tol, use_tree=use_tree)
1139+
bdry_m_vertices = mesh.vertices[:, bdry_m_vertex_indices]
1140+
bdry_n_vertices = mesh.vertices[:, bdry_n_vertex_indices]
1141+
1142+
m_idx_to_n_idx = find_point_to_point_mapping(
1143+
bdry_pair_mapping.aff_map(bdry_m_vertices),
1144+
bdry_n_vertices)
12231145

12241146
unmatched_bdry_m_vertex_indices = bdry_m_vertex_indices[
1225-
np.where(matched_bdry_n_vertex_indices < 0)[0]]
1147+
np.where(m_idx_to_n_idx < 0)[0]]
12261148
nunmatched = len(unmatched_bdry_m_vertex_indices)
12271149
if nunmatched > 0:
12281150
vertices = mesh.vertices[:, unmatched_bdry_m_vertex_indices]
@@ -1235,6 +1157,9 @@ def _match_boundary_faces(
12351157
for i in range(min(nunmatched, 10))])
12361158
+ f"\n...\n({nunmatched-10} more omitted.)" if nunmatched > 10 else "")
12371159

1160+
matched_bdry_n_vertex_indices = bdry_n_vertex_indices[
1161+
m_idx_to_n_idx]
1162+
12381163
from meshmode.mesh import _concatenate_face_ids
12391164
face_ids = _concatenate_face_ids([bdry_m_face_ids, bdry_n_face_ids])
12401165

@@ -1291,13 +1216,15 @@ def glue_mesh_boundaries(
12911216
coordinates of the first boundary and the vertex coordinates of the second
12921217
boundary when attempting to match the two. Pass at most one mapping for each
12931218
unique (order-independent) pair of boundaries.
1294-
:arg use_tree: Optional argument indicating whether to use a spatial binary
1295-
search tree or a (quadratic) numpy algorithm when matching vertices.
12961219
"""
12971220
if any(grp.vertex_indices is None for grp in mesh.groups):
12981221
raise ValueError(
12991222
"gluing mesh boundaries requires 'vertex_indices' in all groups")
13001223

1224+
if use_tree is not None:
1225+
warn("Passing 'use_tree' is deprecated and will be removed "
1226+
"in Q3 2025.", DeprecationWarning, stacklevel=2)
1227+
13011228
glued_btags = {
13021229
btag
13031230
for mapping, _ in bdry_pair_mappings_and_tols
@@ -1320,7 +1247,7 @@ def glue_mesh_boundaries(
13201247
glued_btag_pairs.add(btag_pair)
13211248

13221249
face_id_pairs_for_mapping = [
1323-
_match_boundary_faces(mesh, mapping, tol, use_tree=use_tree)
1250+
_match_boundary_faces(mesh, mapping, tol)
13241251
for mapping, tol in bdry_pair_mappings_and_tols]
13251252

13261253
facial_adjacency_groups = []

0 commit comments

Comments
 (0)