Skip to content

Commit b6571c1

Browse files
committed
feat: allow gathering neighbors from another index list
1 parent 787ac33 commit b6571c1

1 file changed

Lines changed: 46 additions & 25 deletions

File tree

pytential/linalg/proxy.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -689,24 +689,39 @@ def __call__(self,
689689

690690
def gather_cluster_neighbor_points(
691691
actx: PyOpenCLArrayContext,
692-
pxy: ProxyClusterGeometryData, *,
692+
pxy: ProxyClusterGeometryData,
693+
tgtindex: IndexList | None = None,
694+
*,
693695
max_particles_in_box: int | None = None) -> IndexList:
694-
"""Generate a set of neighboring points for each cluster of points in
695-
*pxy*. Neighboring points of a cluster :math:`i` are defined
696-
as all the points inside the proxy ball :math:`i` that do not also
697-
belong to the cluster itself.
696+
r"""Generate a set of neighboring points for each cluster of points in *pxy*.
697+
698+
Neighboring points of a cluster :math:`i` are defined as all the points
699+
from *tgtindex* that are inside the proxy ball :math:`i` but outside the
700+
cluster itself. For example, given a cluster with radius :math:`r_s` and
701+
proxy radius :math:`r_p > r_s`, then we gather all points such that
702+
:math:`r_s < \|\mathbf{x}\| <= r_p`.
698703
"""
699704

705+
srcindex = pxy.srcindex
706+
if tgtindex is None:
707+
tgtindex = srcindex
708+
709+
nclusters = srcindex.nclusters
710+
if tgtindex.nclusters != nclusters:
711+
raise ValueError("'tgtindex' has a different number of clusters: "
712+
f"'{tgtindex.nclusters}' (expected {nclusters})")
713+
700714
if max_particles_in_box is None:
701715
max_particles_in_box = _DEFAULT_MAX_PARTICLES_IN_BOX
702716

703-
from pytential.source import LayerPotentialSourceBase
704-
705717
dofdesc = pxy.dofdesc
706718
lpot_source = pxy.places.get_geometry(dofdesc.geometry)
707-
assert isinstance(lpot_source, LayerPotentialSourceBase)
708-
709719
discr = pxy.places.get_discretization(dofdesc.geometry, dofdesc.discr_stage)
720+
721+
assert (
722+
dofdesc.discr_stage is None
723+
or isinstance(lpot_source, QBXLayerPotentialSource)
724+
), (dofdesc, type(lpot_source))
710725
assert isinstance(discr, Discretization)
711726

712727
# {{{ get only sources in the current cluster set
@@ -734,18 +749,23 @@ def prg() -> lp.ExecutorBase:
734749

735750
return knl.executor(actx.context)
736751

737-
_, (sources,) = prg()(actx.queue,
752+
_, (targets,) = prg()(actx.queue,
738753
ary=flatten(discr.nodes(), actx, leaf_class=DOFArray),
739-
srcindices=pxy.srcindex.indices)
754+
srcindices=tgtindex.indices)
740755

741756
# }}}
742757

743758
# {{{ perform area query
744759

745760
from pytential.qbx.utils import tree_code_container
746-
tcc = tree_code_container(lpot_source._setup_actx)
747761

748-
tree, _ = tcc.build_tree()(actx.queue, sources,
762+
# NOTE: use the base source's actx for caching the code -- that has
763+
# the best chance of surviving even when updating the lpot_source
764+
setup_actx = discr._setup_actx
765+
assert isinstance(setup_actx, PyOpenCLArrayContext)
766+
767+
tcc = tree_code_container(setup_actx)
768+
tree, _ = tcc.build_tree()(actx.queue, targets,
749769
max_particles_in_box=max_particles_in_box)
750770
query, _ = tcc.build_area_query()(actx.queue, tree, pxy.centers, pxy.radii)
751771

@@ -758,10 +778,10 @@ def prg() -> lp.ExecutorBase:
758778

759779
pxycenters = actx.to_numpy(pxy.centers)
760780
pxyradii = actx.to_numpy(pxy.radii)
761-
srcindex = pxy.srcindex
762781

763-
nbrindices: np.ndarray = np.empty(srcindex.nclusters, dtype=object)
764-
for icluster in range(srcindex.nclusters):
782+
eps = 100 * np.finfo(pxyradii.dtype).eps
783+
nbrindices = np.empty(nclusters, dtype=object)
784+
for icluster in range(nclusters):
765785
# get list of boxes intersecting the current ball
766786
istart = query.leaves_near_ball_starts[icluster]
767787
iend = query.leaves_near_ball_starts[icluster + 1]
@@ -780,16 +800,17 @@ def prg() -> lp.ExecutorBase:
780800
isources = tree.user_source_ids[isources]
781801

782802
# get nodes inside the ball but outside the current cluster
783-
# FIXME: this assumes that only the points in `pxy.secindex` should
784-
# count as neighbors, not all the nodes in the discretization.
785-
# FIXME: it also assumes that all the indices are sorted?
786803
center = pxycenters[:, icluster].reshape(-1, 1)
787-
radius = pxyradii[icluster]
788-
mask = ((la.norm(nodes - center, axis=0) < radius)
789-
& ((isources < srcindex.starts[icluster])
790-
| (srcindex.starts[icluster + 1] <= isources)))
791-
792-
nbrindices[icluster] = srcindex.indices[isources[mask]]
804+
radii = la.norm(nodes - center, axis=0) - eps
805+
mask = (
806+
(radii <= pxyradii[icluster])
807+
& ((isources < tgtindex.starts[icluster])
808+
| (tgtindex.starts[icluster + 1] <= isources)))
809+
810+
nbrindices[icluster] = tgtindex.indices[isources[mask]]
811+
if nbrindices[icluster].size == 0:
812+
logger.warning("Cluster '%d' has no neighbors. You might need to "
813+
"increase the proxy 'radius_factor'.", icluster)
793814

794815
# }}}
795816

0 commit comments

Comments
 (0)