99from numba import njit
1010
1111
12+ @njit
13+ def _apply_pbc (vec , dimensions , half_dimensions ):
14+ """
15+ Apply minimum image convention for periodic boundary conditions.
16+
17+ Args:
18+ vec (np.ndarray):
19+ Vector to wrap.
20+ dimensions (np.ndarray):
21+ Simulation box dimensions.
22+ half_dimensions (np.ndarray):
23+ Half box lengths.
24+
25+ Returns:
26+ np.ndarray:
27+ Wrapped vector.
28+ """
29+ for d in range (3 ):
30+ if vec [d ] > half_dimensions [d ]:
31+ vec [d ] -= dimensions [d ]
32+ elif vec [d ] < - half_dimensions [d ]:
33+ vec [d ] += dimensions [d ]
34+ return vec
35+
36+
1237@njit
1338def _rad_blocking_loop (i_coords , sorted_indices , sorted_distances , coms , dimensions ):
1439 """
@@ -38,54 +63,58 @@ def _rad_blocking_loop(i_coords, sorted_indices, sorted_distances, coms, dimensi
3863 Simulation box dimensions for periodic boundary conditions.
3964
4065 Returns:
41- list[int] :
66+ np.ndarray :
4267 Indices of molecules that belong to the RAD neighbor shell.
4368 """
44- shell = []
45-
4669 n = sorted_indices .shape [0 ]
4770 limit = min (n , 30 )
4871
72+ half_dimensions = 0.5 * dimensions
73+
74+ inv_r2 = 1.0 / (sorted_distances * sorted_distances )
75+
76+ shell = np .empty (limit , dtype = np .int64 )
77+ count = 0
78+
4979 for y in range (limit ):
5080 j_idx = sorted_indices [y ]
5181 r_ij = sorted_distances [y ]
5282 j_coords = coms [j_idx ]
5383
84+ ba = j_coords - i_coords
85+ ba = _apply_pbc (ba , dimensions , half_dimensions )
86+
5487 blocked = False
5588
5689 for z in range (y ):
5790 k_idx = sorted_indices [z ]
5891 r_ik = sorted_distances [z ]
59- k_coords = coms [k_idx ]
6092
61- ba = np .abs (j_coords - i_coords )
62- bc = np .abs (k_coords - i_coords )
63- ac = np .abs (k_coords - j_coords )
93+ if r_ik > r_ij :
94+ continue
6495
65- ba = np .where (ba > 0.5 * dimensions , ba - dimensions , ba )
66- bc = np .where (bc > 0.5 * dimensions , bc - dimensions , bc )
67- ac = np .where (ac > 0.5 * dimensions , ac - dimensions , ac )
96+ k_coords = coms [k_idx ]
6897
69- dist_ba = np .sqrt ((ba ** 2 ).sum ())
70- dist_bc = np .sqrt ((bc ** 2 ).sum ())
71- dist_ac = np .sqrt ((ac ** 2 ).sum ())
98+ ac = k_coords - j_coords
99+ ac = _apply_pbc (ac , dimensions , half_dimensions )
72100
73- costheta = (dist_ac ** 2 - dist_bc ** 2 - dist_ba ** 2 ) / ( - 2 * dist_bc * dist_ba )
101+ dist_ac2 = (ac * ac ). sum ( )
74102
75- if np .isnan (costheta ):
76- break
103+ denom = - 2.0 * r_ik * r_ij
104+ if denom == 0.0 :
105+ continue
77106
78- LHS = (1.0 / r_ij ) ** 2
79- RHS = ((1.0 / r_ik ) ** 2 ) * costheta
107+ costheta = (dist_ac2 - r_ik * r_ik - r_ij * r_ij ) / denom
80108
81- if LHS < RHS :
109+ if inv_r2 [ y ] < inv_r2 [ z ] * costheta :
82110 blocked = True
83111 break
84112
85113 if not blocked :
86- shell .append (j_idx )
114+ shell [count ] = j_idx
115+ count += 1
87116
88- return shell
117+ return shell [: count ]
89118
90119
91120class Search :
@@ -114,16 +143,13 @@ def _update_cache(self, universe):
114143 universe (MDAnalysis.Universe):
115144 MDAnalysis universe object containing the system.
116145 """
117- # Get current frame index (MDAnalysis trajectory)
118146 current_frame = universe .trajectory .ts .frame
119147
120- # Only recompute if frame has changed
121148 if self ._cached_frame == current_frame :
122149 return
123150
124151 fragments = universe .atoms .fragments
125152
126- # Compute COMs once per frame (deterministic snapshot)
127153 coms = np .array ([frag .center_of_mass () for frag in fragments ])
128154
129155 self ._cached_fragments = fragments
@@ -149,9 +175,22 @@ def _get_distances(self, coms, i_coords, dimensions):
149175 Distances from the central molecule to all fragments.
150176 """
151177 delta = coms - i_coords
152- delta = np .where (delta > 0.5 * dimensions , delta - dimensions , delta )
153- delta = np .where (delta < - 0.5 * dimensions , delta + dimensions , delta )
154- return np .sqrt ((delta ** 2 ).sum (axis = 1 ))
178+
179+ half_dimensions = 0.5 * dimensions
180+
181+ for d in range (3 ):
182+ delta [:, d ] = np .where (
183+ delta [:, d ] > half_dimensions [d ],
184+ delta [:, d ] - dimensions [d ],
185+ delta [:, d ],
186+ )
187+ delta [:, d ] = np .where (
188+ delta [:, d ] < - half_dimensions [d ],
189+ delta [:, d ] + dimensions [d ],
190+ delta [:, d ],
191+ )
192+
193+ return np .sqrt ((delta * delta ).sum (axis = 1 ))
155194
156195 def get_RAD_neighbors (self , universe , mol_id ):
157196 """
@@ -164,10 +203,9 @@ def get_RAD_neighbors(self, universe, mol_id):
164203 Index of the central molecule.
165204
166205 Returns:
167- list[int] :
206+ np.ndarray :
168207 Indices of neighboring molecules identified via the RAD method.
169208 """
170- # Ensure cache corresponds to current frame
171209 self ._update_cache (universe )
172210
173211 fragments = self ._cached_fragments
@@ -178,7 +216,6 @@ def get_RAD_neighbors(self, universe, mol_id):
178216
179217 central_position = coms [mol_id ]
180218
181- # Distances computed from same COM snapshot
182219 distances_array = self ._get_distances (coms , central_position , dimensions )
183220
184221 indices = np .arange (number_molecules )
@@ -187,7 +224,6 @@ def get_RAD_neighbors(self, universe, mol_id):
187224 filtered_indices = indices [mask ]
188225 filtered_distances = distances_array [mask ]
189226
190- # Stable sort to avoid ordering ambiguity
191227 order = np .argsort (filtered_distances , kind = "mergesort" )
192228
193229 sorted_indices = filtered_indices [order ]
@@ -219,7 +255,7 @@ def get_grid_neighbors(self, universe, mol_id, highest_level):
219255 Molecule level ("united_atom" or other).
220256
221257 Returns:
222- list[int] :
258+ np.ndarray :
223259 Fragment indices of neighboring molecules.
224260 """
225261 fragments = universe .atoms .fragments
0 commit comments