@@ -689,24 +689,39 @@ def __call__(self,
689689
690690def gather_cluster_neighbor_points (
691691 actx : PyOpenCLArrayContext ,
692- pxy : ProxyClusterGeometryData , * ,
692+ pxy : ProxyClusterGeometryData ,
693+ tgtindex : IndexList | None = None ,
694+ * ,
693695 max_particles_in_box : int | None = None ) -> IndexList :
694- """Generate a set of neighboring points for each cluster of points in
695- *pxy*. Neighboring points of a cluster :math:`i` are defined
696- as all the points inside the proxy ball :math:`i` that do not also
697- belong to the cluster itself.
696+ r"""Generate a set of neighboring points for each cluster of points in *pxy*.
697+
698+ Neighboring points of a cluster :math:`i` are defined as all the points
699+ from *tgtindex* that are inside the proxy ball :math:`i` but outside the
700+ cluster itself. For example, given a cluster with radius :math:`r_s` and
701+ proxy radius :math:`r_p > r_s`, then we gather all points such that
702+ :math:`r_s < \|\mathbf{x}\| <= r_p`.
698703 """
699704
705+ srcindex = pxy .srcindex
706+ if tgtindex is None :
707+ tgtindex = srcindex
708+
709+ nclusters = srcindex .nclusters
710+ if tgtindex .nclusters != nclusters :
711+ raise ValueError ("'tgtindex' has a different number of clusters: "
712+ f"'{ tgtindex .nclusters } ' (expected { nclusters } )" )
713+
700714 if max_particles_in_box is None :
701715 max_particles_in_box = _DEFAULT_MAX_PARTICLES_IN_BOX
702716
703- from pytential .source import LayerPotentialSourceBase
704-
705717 dofdesc = pxy .dofdesc
706718 lpot_source = pxy .places .get_geometry (dofdesc .geometry )
707- assert isinstance (lpot_source , LayerPotentialSourceBase )
708-
709719 discr = pxy .places .get_discretization (dofdesc .geometry , dofdesc .discr_stage )
720+
721+ assert (
722+ dofdesc .discr_stage is None
723+ or isinstance (lpot_source , QBXLayerPotentialSource )
724+ ), (dofdesc , type (lpot_source ))
710725 assert isinstance (discr , Discretization )
711726
712727 # {{{ get only sources in the current cluster set
@@ -734,18 +749,23 @@ def prg() -> lp.ExecutorBase:
734749
735750 return knl .executor (actx .context )
736751
737- _ , (sources ,) = prg ()(actx .queue ,
752+ _ , (targets ,) = prg ()(actx .queue ,
738753 ary = flatten (discr .nodes (), actx , leaf_class = DOFArray ),
739- srcindices = pxy . srcindex .indices )
754+ srcindices = tgtindex .indices )
740755
741756 # }}}
742757
743758 # {{{ perform area query
744759
745760 from pytential .qbx .utils import tree_code_container
746- tcc = tree_code_container (lpot_source ._setup_actx )
747761
748- tree , _ = tcc .build_tree ()(actx .queue , sources ,
762+ # NOTE: use the base source's actx for caching the code -- that has
763+ # the best chance of surviving even when updating the lpot_source
764+ setup_actx = discr ._setup_actx
765+ assert isinstance (setup_actx , PyOpenCLArrayContext )
766+
767+ tcc = tree_code_container (setup_actx )
768+ tree , _ = tcc .build_tree ()(actx .queue , targets ,
749769 max_particles_in_box = max_particles_in_box )
750770 query , _ = tcc .build_area_query ()(actx .queue , tree , pxy .centers , pxy .radii )
751771
@@ -758,10 +778,10 @@ def prg() -> lp.ExecutorBase:
758778
759779 pxycenters = actx .to_numpy (pxy .centers )
760780 pxyradii = actx .to_numpy (pxy .radii )
761- srcindex = pxy .srcindex
762781
763- nbrindices : np .ndarray = np .empty (srcindex .nclusters , dtype = object )
764- for icluster in range (srcindex .nclusters ):
782+ eps = 100 * np .finfo (pxyradii .dtype ).eps
783+ nbrindices = np .empty (nclusters , dtype = object )
784+ for icluster in range (nclusters ):
765785 # get list of boxes intersecting the current ball
766786 istart = query .leaves_near_ball_starts [icluster ]
767787 iend = query .leaves_near_ball_starts [icluster + 1 ]
@@ -780,16 +800,17 @@ def prg() -> lp.ExecutorBase:
780800 isources = tree .user_source_ids [isources ]
781801
782802 # get nodes inside the ball but outside the current cluster
783- # FIXME: this assumes that only the points in `pxy.secindex` should
784- # count as neighbors, not all the nodes in the discretization.
785- # FIXME: it also assumes that all the indices are sorted?
786803 center = pxycenters [:, icluster ].reshape (- 1 , 1 )
787- radius = pxyradii [icluster ]
788- mask = ((la .norm (nodes - center , axis = 0 ) < radius )
789- & ((isources < srcindex .starts [icluster ])
790- | (srcindex .starts [icluster + 1 ] <= isources )))
791-
792- nbrindices [icluster ] = srcindex .indices [isources [mask ]]
804+ radii = la .norm (nodes - center , axis = 0 ) - eps
805+ mask = (
806+ (radii <= pxyradii [icluster ])
807+ & ((isources < tgtindex .starts [icluster ])
808+ | (tgtindex .starts [icluster + 1 ] <= isources )))
809+
810+ nbrindices [icluster ] = tgtindex .indices [isources [mask ]]
811+ if nbrindices [icluster ].size == 0 :
812+ logger .warning ("Cluster '%d' has no neighbors. You might need to "
813+ "increase the proxy 'radius_factor'." , icluster )
793814
794815 # }}}
795816
0 commit comments