Skip to content

Commit 74c9624

Browse files
committed
perf: further optimise RAD neighbour search by reducing inner-loop overhead
1 parent cdfd4ae commit 74c9624

1 file changed

Lines changed: 68 additions & 32 deletions

File tree

CodeEntropy/levels/search.py

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,31 @@
99
from numba import njit
1010

1111

12+
@njit
def _apply_pbc(vec, dimensions, half_dimensions):
    """
    Wrap a displacement vector with the minimum image convention.

    Each of the first three components is shifted by at most one box
    length: a component larger than the half box length has the box
    length subtracted, a component smaller than the negative half box
    length has it added. Components already inside that range are left
    untouched.

    Args:
        vec (np.ndarray):
            Vector to wrap. It is modified in place.
        dimensions (np.ndarray):
            Simulation box dimensions.
        half_dimensions (np.ndarray):
            Half box lengths.

    Returns:
        np.ndarray:
            The same ``vec`` object that was passed in, after wrapping.
    """
    for axis in range(3):
        component = vec[axis]
        half_box = half_dimensions[axis]

        if component > half_box:
            # Too far in the positive direction: pull back by one box length.
            vec[axis] = component - dimensions[axis]
        elif component < -half_box:
            # Too far in the negative direction: push forward by one box length.
            vec[axis] = component + dimensions[axis]

    return vec
35+
36+
1237
@njit
1338
def _rad_blocking_loop(i_coords, sorted_indices, sorted_distances, coms, dimensions):
1439
"""
@@ -38,54 +63,58 @@ def _rad_blocking_loop(i_coords, sorted_indices, sorted_distances, coms, dimensi
3863
Simulation box dimensions for periodic boundary conditions.
3964
4065
Returns:
41-
list[int]:
66+
np.ndarray:
4267
Indices of molecules that belong to the RAD neighbor shell.
4368
"""
44-
shell = []
45-
4669
n = sorted_indices.shape[0]
4770
limit = min(n, 30)
4871

72+
half_dimensions = 0.5 * dimensions
73+
74+
inv_r2 = 1.0 / (sorted_distances * sorted_distances)
75+
76+
shell = np.empty(limit, dtype=np.int64)
77+
count = 0
78+
4979
for y in range(limit):
5080
j_idx = sorted_indices[y]
5181
r_ij = sorted_distances[y]
5282
j_coords = coms[j_idx]
5383

84+
ba = j_coords - i_coords
85+
ba = _apply_pbc(ba, dimensions, half_dimensions)
86+
5487
blocked = False
5588

5689
for z in range(y):
5790
k_idx = sorted_indices[z]
5891
r_ik = sorted_distances[z]
59-
k_coords = coms[k_idx]
6092

61-
ba = np.abs(j_coords - i_coords)
62-
bc = np.abs(k_coords - i_coords)
63-
ac = np.abs(k_coords - j_coords)
93+
if r_ik > r_ij:
94+
continue
6495

65-
ba = np.where(ba > 0.5 * dimensions, ba - dimensions, ba)
66-
bc = np.where(bc > 0.5 * dimensions, bc - dimensions, bc)
67-
ac = np.where(ac > 0.5 * dimensions, ac - dimensions, ac)
96+
k_coords = coms[k_idx]
6897

69-
dist_ba = np.sqrt((ba**2).sum())
70-
dist_bc = np.sqrt((bc**2).sum())
71-
dist_ac = np.sqrt((ac**2).sum())
98+
ac = k_coords - j_coords
99+
ac = _apply_pbc(ac, dimensions, half_dimensions)
72100

73-
costheta = (dist_ac**2 - dist_bc**2 - dist_ba**2) / (-2 * dist_bc * dist_ba)
101+
dist_ac2 = (ac * ac).sum()
74102

75-
if np.isnan(costheta):
76-
break
103+
denom = -2.0 * r_ik * r_ij
104+
if denom == 0.0:
105+
continue
77106

78-
LHS = (1.0 / r_ij) ** 2
79-
RHS = ((1.0 / r_ik) ** 2) * costheta
107+
costheta = (dist_ac2 - r_ik * r_ik - r_ij * r_ij) / denom
80108

81-
if LHS < RHS:
109+
if inv_r2[y] < inv_r2[z] * costheta:
82110
blocked = True
83111
break
84112

85113
if not blocked:
86-
shell.append(j_idx)
114+
shell[count] = j_idx
115+
count += 1
87116

88-
return shell
117+
return shell[:count]
89118

90119

91120
class Search:
@@ -114,16 +143,13 @@ def _update_cache(self, universe):
114143
universe (MDAnalysis.Universe):
115144
MDAnalysis universe object containing the system.
116145
"""
117-
# Get current frame index (MDAnalysis trajectory)
118146
current_frame = universe.trajectory.ts.frame
119147

120-
# Only recompute if frame has changed
121148
if self._cached_frame == current_frame:
122149
return
123150

124151
fragments = universe.atoms.fragments
125152

126-
# Compute COMs once per frame (deterministic snapshot)
127153
coms = np.array([frag.center_of_mass() for frag in fragments])
128154

129155
self._cached_fragments = fragments
@@ -149,9 +175,22 @@ def _get_distances(self, coms, i_coords, dimensions):
149175
Distances from the central molecule to all fragments.
150176
"""
151177
delta = coms - i_coords
152-
delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
153-
delta = np.where(delta < -0.5 * dimensions, delta + dimensions, delta)
154-
return np.sqrt((delta**2).sum(axis=1))
178+
179+
half_dimensions = 0.5 * dimensions
180+
181+
for d in range(3):
182+
delta[:, d] = np.where(
183+
delta[:, d] > half_dimensions[d],
184+
delta[:, d] - dimensions[d],
185+
delta[:, d],
186+
)
187+
delta[:, d] = np.where(
188+
delta[:, d] < -half_dimensions[d],
189+
delta[:, d] + dimensions[d],
190+
delta[:, d],
191+
)
192+
193+
return np.sqrt((delta * delta).sum(axis=1))
155194

156195
def get_RAD_neighbors(self, universe, mol_id):
157196
"""
@@ -164,10 +203,9 @@ def get_RAD_neighbors(self, universe, mol_id):
164203
Index of the central molecule.
165204
166205
Returns:
167-
list[int]:
206+
np.ndarray:
168207
Indices of neighboring molecules identified via the RAD method.
169208
"""
170-
# Ensure cache corresponds to current frame
171209
self._update_cache(universe)
172210

173211
fragments = self._cached_fragments
@@ -178,7 +216,6 @@ def get_RAD_neighbors(self, universe, mol_id):
178216

179217
central_position = coms[mol_id]
180218

181-
# Distances computed from same COM snapshot
182219
distances_array = self._get_distances(coms, central_position, dimensions)
183220

184221
indices = np.arange(number_molecules)
@@ -187,7 +224,6 @@ def get_RAD_neighbors(self, universe, mol_id):
187224
filtered_indices = indices[mask]
188225
filtered_distances = distances_array[mask]
189226

190-
# Stable sort to avoid ordering ambiguity
191227
order = np.argsort(filtered_distances, kind="mergesort")
192228

193229
sorted_indices = filtered_indices[order]
@@ -219,7 +255,7 @@ def get_grid_neighbors(self, universe, mol_id, highest_level):
219255
Molecule level ("united_atom" or other).
220256
221257
Returns:
222-
list[int]:
258+
np.ndarray:
223259
Fragment indices of neighboring molecules.
224260
"""
225261
fragments = universe.atoms.fragments

0 commit comments

Comments
 (0)