Skip to content

Commit 1597705

Browse files
committed
perf: make RAD neighbour search frame-safe by caching COMs per frame and enforcing deterministic ordering to restore identical results
1 parent e3f68d0 commit 1597705

1 file changed

Lines changed: 40 additions & 30 deletions

File tree

CodeEntropy/levels/search.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,18 @@ def _rad_blocking_loop(i_coords, sorted_indices, sorted_distances, coms, dimensi
5858
r_ik = sorted_distances[z]
5959
k_coords = coms[k_idx]
6060

61-
# Compute coordinate differences
6261
ba = np.abs(j_coords - i_coords)
6362
bc = np.abs(k_coords - i_coords)
6463
ac = np.abs(k_coords - j_coords)
6564

66-
# Apply periodic boundary conditions
6765
ba = np.where(ba > 0.5 * dimensions, ba - dimensions, ba)
6866
bc = np.where(bc > 0.5 * dimensions, bc - dimensions, bc)
6967
ac = np.where(ac > 0.5 * dimensions, ac - dimensions, ac)
7068

71-
# Compute distances
7269
dist_ba = np.sqrt((ba**2).sum())
7370
dist_bc = np.sqrt((bc**2).sum())
7471
dist_ac = np.sqrt((ac**2).sum())
7572

76-
# Cosine of angle jik
7773
costheta = (dist_ac**2 - dist_bc**2 - dist_ba**2) / (-2 * dist_bc * dist_ba)
7874

7975
if np.isnan(costheta):
@@ -101,25 +97,39 @@ def __init__(self):
10197
"""
10298
Initialize the Search class.
10399
104-
This class currently serves as a container for neighbor search
105-
methods operating on an MDAnalysis universe.
100+
This class includes frame-safe caching of fragment COMs and
101+
system dimensions to avoid recomputation while preserving
102+
identical results to the original implementation.
106103
"""
107-
self._universe = None
108-
self._mol_id = None
104+
self._cached_frame = None
105+
self._cached_fragments = None
106+
self._cached_coms = None
107+
self._cached_dimensions = None
109108

110-
def _get_fragment_coms(self, universe):
109+
def _update_cache(self, universe):
111110
"""
112-
Precompute center of mass for each molecular fragment.
111+
Update cached MDAnalysis data if the simulation frame has changed.
113112
114113
Args:
115114
universe (MDAnalysis.Universe):
116115
MDAnalysis universe object containing the system.
117-
118-
Returns:
119-
np.ndarray:
120-
Array of shape (n_fragments, 3) containing COM coordinates.
121116
"""
122-
return np.array([frag.center_of_mass() for frag in universe.atoms.fragments])
117+
# Get current frame index (MDAnalysis trajectory)
118+
current_frame = universe.trajectory.ts.frame
119+
120+
# Only recompute if frame has changed
121+
if self._cached_frame == current_frame:
122+
return
123+
124+
fragments = universe.atoms.fragments
125+
126+
# Compute COMs once per frame (deterministic snapshot)
127+
coms = np.array([frag.center_of_mass() for frag in fragments])
128+
129+
self._cached_fragments = fragments
130+
self._cached_coms = coms
131+
self._cached_dimensions = universe.dimensions[:3]
132+
self._cached_frame = current_frame
123133

124134
def _get_distances(self, coms, i_coords, dimensions):
125135
"""
@@ -157,40 +167,38 @@ def get_RAD_neighbors(self, universe, mol_id):
157167
list[int]:
158168
Indices of neighboring molecules identified via the RAD method.
159169
"""
160-
number_molecules = len(universe.atoms.fragments)
170+
# Ensure cache corresponds to current frame
171+
self._update_cache(universe)
172+
173+
fragments = self._cached_fragments
174+
coms = self._cached_coms
175+
dimensions = self._cached_dimensions
161176

162-
# Precompute COMs
163-
coms = self._get_fragment_coms(universe)
177+
number_molecules = len(fragments)
164178

165-
# Central molecule position
166179
central_position = coms[mol_id]
167180

168-
# Compute distances
169-
distances_array = self._get_distances(
170-
coms, central_position, universe.dimensions[:3]
171-
)
181+
# Distances computed from same COM snapshot
182+
distances_array = self._get_distances(coms, central_position, dimensions)
172183

173-
# Prepare indices
174184
indices = np.arange(number_molecules)
175185

176-
# Remove self
177186
mask = indices != mol_id
178187
filtered_indices = indices[mask]
179188
filtered_distances = distances_array[mask]
180189

181-
# Sort by distance
182-
order = np.argsort(filtered_distances)
190+
# Stable sort to avoid ordering ambiguity
191+
order = np.argsort(filtered_distances, kind="mergesort")
183192

184193
sorted_indices = filtered_indices[order]
185194
sorted_distances = filtered_distances[order]
186195

187-
# RAD blocking (Numba)
188196
neighbor_indices = _rad_blocking_loop(
189197
central_position,
190198
sorted_indices,
191199
sorted_distances,
192200
coms,
193-
universe.dimensions[:3],
201+
dimensions,
194202
)
195203

196204
return neighbor_indices
@@ -214,8 +222,10 @@ def get_grid_neighbors(self, universe, mol_id, highest_level):
214222
list[int]:
215223
Fragment indices of neighboring molecules.
216224
"""
225+
fragments = universe.atoms.fragments
226+
fragment = fragments[mol_id]
227+
217228
search_object = mda.lib.NeighborSearch.AtomNeighborSearch(universe.atoms)
218-
fragment = universe.atoms.fragments[mol_id]
219229

220230
selection_string = f"index {fragment.indices[0]}:{fragment.indices[-1]}"
221231
molecule_atom_group = universe.select_atoms(selection_string)

0 commit comments

Comments
 (0)