@@ -58,22 +58,18 @@ def _rad_blocking_loop(i_coords, sorted_indices, sorted_distances, coms, dimensi
5858 r_ik = sorted_distances [z ]
5959 k_coords = coms [k_idx ]
6060
61- # Compute coordinate differences
6261 ba = np .abs (j_coords - i_coords )
6362 bc = np .abs (k_coords - i_coords )
6463 ac = np .abs (k_coords - j_coords )
6564
66- # Apply periodic boundary conditions
6765 ba = np .where (ba > 0.5 * dimensions , ba - dimensions , ba )
6866 bc = np .where (bc > 0.5 * dimensions , bc - dimensions , bc )
6967 ac = np .where (ac > 0.5 * dimensions , ac - dimensions , ac )
7068
71- # Compute distances
7269 dist_ba = np .sqrt ((ba ** 2 ).sum ())
7370 dist_bc = np .sqrt ((bc ** 2 ).sum ())
7471 dist_ac = np .sqrt ((ac ** 2 ).sum ())
7572
76- # Cosine of angle jik
7773 costheta = (dist_ac ** 2 - dist_bc ** 2 - dist_ba ** 2 ) / (- 2 * dist_bc * dist_ba )
7874
7975 if np .isnan (costheta ):
@@ -101,25 +97,39 @@ def __init__(self):
10197 """
10298 Initialize the Search class.
10399
104- This class currently serves as a container for neighbor search
105- methods operating on an MDAnalysis universe.
100+ This class includes frame-safe caching of fragment COMs and
101+ system dimensions to avoid recomputation while preserving
102+ identical results to the original implementation.
106103 """
107- self ._universe = None
108- self ._mol_id = None
104+ self ._cached_frame = None
105+ self ._cached_fragments = None
106+ self ._cached_coms = None
107+ self ._cached_dimensions = None
109108
110- def _get_fragment_coms (self , universe ):
109+ def _update_cache (self , universe ):
111110 """
112- Precompute center of mass for each molecular fragment .
111+ Update cached MDAnalysis data if the simulation frame has changed .
113112
114113 Args:
115114 universe (MDAnalysis.Universe):
116115 MDAnalysis universe object containing the system.
117-
118- Returns:
119- np.ndarray:
120- Array of shape (n_fragments, 3) containing COM coordinates.
121116 """
122- return np .array ([frag .center_of_mass () for frag in universe .atoms .fragments ])
117+ # Get current frame index (MDAnalysis trajectory)
118+ current_frame = universe .trajectory .ts .frame
119+
120+ # Only recompute if frame has changed
121+ if self ._cached_frame == current_frame :
122+ return
123+
124+ fragments = universe .atoms .fragments
125+
126+ # Compute COMs once per frame (deterministic snapshot)
127+ coms = np .array ([frag .center_of_mass () for frag in fragments ])
128+
129+ self ._cached_fragments = fragments
130+ self ._cached_coms = coms
131+ self ._cached_dimensions = universe .dimensions [:3 ]
132+ self ._cached_frame = current_frame
123133
124134 def _get_distances (self , coms , i_coords , dimensions ):
125135 """
@@ -157,40 +167,38 @@ def get_RAD_neighbors(self, universe, mol_id):
157167 list[int]:
158168 Indices of neighboring molecules identified via the RAD method.
159169 """
160- number_molecules = len (universe .atoms .fragments )
170+ # Ensure cache corresponds to current frame
171+ self ._update_cache (universe )
172+
173+ fragments = self ._cached_fragments
174+ coms = self ._cached_coms
175+ dimensions = self ._cached_dimensions
161176
162- # Precompute COMs
163- coms = self ._get_fragment_coms (universe )
177+ number_molecules = len (fragments )
164178
165- # Central molecule position
166179 central_position = coms [mol_id ]
167180
168- # Compute distances
169- distances_array = self ._get_distances (
170- coms , central_position , universe .dimensions [:3 ]
171- )
181+ # Distances computed from same COM snapshot
182+ distances_array = self ._get_distances (coms , central_position , dimensions )
172183
173- # Prepare indices
174184 indices = np .arange (number_molecules )
175185
176- # Remove self
177186 mask = indices != mol_id
178187 filtered_indices = indices [mask ]
179188 filtered_distances = distances_array [mask ]
180189
181- # Sort by distance
182- order = np .argsort (filtered_distances )
190+ # Stable sort to avoid ordering ambiguity
191+ order = np .argsort (filtered_distances , kind = "mergesort" )
183192
184193 sorted_indices = filtered_indices [order ]
185194 sorted_distances = filtered_distances [order ]
186195
187- # RAD blocking (Numba)
188196 neighbor_indices = _rad_blocking_loop (
189197 central_position ,
190198 sorted_indices ,
191199 sorted_distances ,
192200 coms ,
193- universe . dimensions [: 3 ] ,
201+ dimensions ,
194202 )
195203
196204 return neighbor_indices
@@ -214,8 +222,10 @@ def get_grid_neighbors(self, universe, mol_id, highest_level):
214222 list[int]:
215223 Fragment indices of neighboring molecules.
216224 """
225+ fragments = universe .atoms .fragments
226+ fragment = fragments [mol_id ]
227+
217228 search_object = mda .lib .NeighborSearch .AtomNeighborSearch (universe .atoms )
218- fragment = universe .atoms .fragments [mol_id ]
219229
220230 selection_string = f"index { fragment .indices [0 ]} :{ fragment .indices [- 1 ]} "
221231 molecule_atom_group = universe .select_atoms (selection_string )
0 commit comments