Skip to content

Commit e3f68d0

Browse files
committed
perf: accelerate RAD neighbour search using NumPy sorting and Numba-compiled blocking loop
1 parent 9bebe97 commit e3f68d0

3 files changed

Lines changed: 142 additions & 141 deletions

File tree

CodeEntropy/levels/search.py

Lines changed: 140 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,90 @@
66

77
import MDAnalysis as mda
88
import numpy as np
9+
from numba import njit
10+
11+
12+
@njit
13+
def _rad_blocking_loop(
    i_coords, sorted_indices, sorted_distances, coms, dimensions, max_candidates=30
):
    """
    Perform RAD neighbor selection using a blocking criterion.

    This function is compiled with Numba (see the ``@njit`` decorator) and
    must therefore stay within the Numba-supported subset of Python/NumPy.

    Candidates are visited from nearest to farthest. A candidate j is kept
    as a neighbor of the central molecule i unless some closer candidate k
    blocks it. Molecule j is a neighbor with respect to k when

        (1 / r_ij)^2 > (1 / r_ik)^2 * cos(theta_jik)

    and k blocks j when the opposite strict inequality holds, i.e.

        (1 / r_ij)^2 < (1 / r_ik)^2 * cos(theta_jik)

    where theta_jik is the angle at i between j and k, evaluated from
    periodic-image separations via the law of cosines.

    If the angle is undefined for some k (a zero separation between i and
    j or between i and k, or a NaN cosine), scanning of further blockers
    for that j stops and j is kept unless it was already blocked.

    Args:
        i_coords (np.ndarray):
            Coordinates of the central molecule.
        sorted_indices (np.ndarray):
            Indices of molecules sorted by distance from i.
        sorted_distances (np.ndarray):
            Distances corresponding to sorted_indices.
        coms (np.ndarray):
            Precomputed center of mass coordinates for all molecules.
        dimensions (np.ndarray):
            Simulation box dimensions for periodic boundary conditions.
        max_candidates (int):
            Maximum number of nearest candidates to examine. Defaults to
            30, the cut-off that was previously hard-coded.

    Returns:
        list[int]:
            Indices of molecules that belong to the RAD neighbor shell,
            in order of increasing distance.
    """
    shell = []

    # Loop invariants.
    half_box = 0.5 * dimensions
    limit = min(sorted_indices.shape[0], max_candidates)

    for y in range(limit):
        j_idx = sorted_indices[y]
        r_ij = sorted_distances[y]
        j_coords = coms[j_idx]

        # The i-j separation depends only on j, so it is computed once per
        # candidate instead of once per (j, k) pair.
        ba = np.abs(j_coords - i_coords)
        ba = np.where(ba > half_box, ba - dimensions, ba)
        dist_ba = np.sqrt((ba**2).sum())

        blocked = False

        for z in range(y):
            k_idx = sorted_indices[z]
            r_ik = sorted_distances[z]
            k_coords = coms[k_idx]

            # Periodic-image separations i-k and j-k.
            bc = np.abs(k_coords - i_coords)
            ac = np.abs(k_coords - j_coords)
            bc = np.where(bc > half_box, bc - dimensions, bc)
            ac = np.where(ac > half_box, ac - dimensions, ac)

            dist_bc = np.sqrt((bc**2).sum())
            dist_ac = np.sqrt((ac**2).sum())

            # Law of cosines for the angle jik. With Numba's default
            # error model a float division by zero raises
            # ZeroDivisionError instead of yielding NaN, so the undefined
            # angle case has to be detected before dividing.
            denominator = -2 * dist_bc * dist_ba
            if denominator == 0.0:
                break

            costheta = (dist_ac**2 - dist_bc**2 - dist_ba**2) / denominator

            if np.isnan(costheta):
                break

            LHS = (1.0 / r_ij) ** 2
            RHS = ((1.0 / r_ik) ** 2) * costheta

            if LHS < RHS:
                blocked = True
                break

        if not blocked:
            shell.append(j_idx)

    return shell
993

1094

1195
class Search:
@@ -15,36 +99,44 @@ class Search:
1599

16100
def __init__(self):
17101
"""
18-
Initializes the Search class with a placeholder for the system
19-
trajectory.
102+
Initialize the Search class.
103+
104+
This class currently serves as a container for neighbor search
105+
methods operating on an MDAnalysis universe.
20106
"""
21107
self._universe = None
22108
self._mol_id = None
23109

24110
def _get_fragment_coms(self, universe):
25111
"""
26-
Precompute fragment centres of mass.
112+
Precompute center of mass for each molecular fragment.
27113
28114
Args:
29-
universe: MDAnalysis universe object.
115+
universe (MDAnalysis.Universe):
116+
MDAnalysis universe object containing the system.
30117
31118
Returns:
32-
np.ndarray: Array of fragment COMs.
119+
np.ndarray:
120+
Array of shape (n_fragments, 3) containing COM coordinates.
33121
"""
34122
return np.array([frag.center_of_mass() for frag in universe.atoms.fragments])
35123

36124
def _get_distances(self, coms, i_coords, dimensions):
37125
"""
38-
Function to calculate distances between a central point and all COMs.
39-
Takes periodic boundary conditions into account.
126+
Compute distances between a central coordinate and all fragment COMs
127+
using periodic boundary conditions.
40128
41129
Args:
42-
coms: array of fragment COMs
43-
i_coords: coordinates of central molecule
44-
dimensions: simulation box dimensions
130+
coms (np.ndarray):
131+
Array of fragment center of mass coordinates.
132+
i_coords (np.ndarray):
133+
Coordinates of the reference (central) molecule.
134+
dimensions (np.ndarray):
135+
Simulation box dimensions.
45136
46137
Returns:
47-
np.ndarray: distances to all molecules
138+
np.ndarray:
139+
Distances from the central molecule to all fragments.
48140
"""
49141
delta = coms - i_coords
50142
delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
@@ -53,171 +145,78 @@ def _get_distances(self, coms, i_coords, dimensions):
53145

54146
def get_RAD_neighbors(self, universe, mol_id):
    """
    Find RAD neighbors of a given molecule.

    Fragment centers of mass are computed once, their distances to the
    central molecule are sorted in ascending order, and the sorted
    candidates are passed to the compiled RAD blocking loop.

    Args:
        universe (MDAnalysis.Universe):
            The MDAnalysis universe of the system.
        mol_id (int):
            Index of the central molecule.

    Returns:
        list[int]:
            Indices of neighboring molecules identified via the RAD method.
    """
    number_molecules = len(universe.atoms.fragments)

    # Precompute COMs and read the box edge lengths once.
    coms = self._get_fragment_coms(universe)
    box = universe.dimensions[:3]

    # Central molecule position
    central_position = coms[mol_id]

    # Distances from the central molecule to every fragment
    distances_array = self._get_distances(coms, central_position, box)

    # Candidate list: every molecule except the central one
    indices = np.arange(number_molecules)
    keep = indices != mol_id
    candidate_indices = indices[keep]
    candidate_distances = distances_array[keep]

    # A stable sort keeps equidistant molecules in ascending index order,
    # matching the tie-breaking of the earlier ``sorted``-based code and
    # making the blocking order deterministic.
    order = np.argsort(candidate_distances, kind="stable")
    sorted_indices = candidate_indices[order]
    sorted_distances = candidate_distances[order]

    # RAD blocking (Numba)
    neighbor_indices = _rad_blocking_loop(
        central_position,
        sorted_indices,
        sorted_distances,
        coms,
        box,
    )

    return neighbor_indices
98197

99-
def _get_RAD_indices(
100-
self, i_coords, sorted_distances, coms, dimensions, number_molecules
101-
):
102-
# pylint: disable=too-many-locals
103-
r"""
104-
For a given set of atom coordinates, find its RAD shell from the distance
105-
sorted list, truncated to the closest 30 molecules.
106-
107-
This function calculates coordination shells using RAD the relative
108-
angular distance, as defined first in DOI:10.1063/1.4961439
109-
where molecules are defined as neighbors if
110-
they fulfil the following condition:
111-
112-
.. math::
113-
\Bigg(\frac{1}{r_{ij}}\Bigg)^2 >
114-
\Bigg(\frac{1}{r_{ik}}\Bigg)^2 \cos \theta_{jik}
115-
116-
For a given particle :math:`i`, neighbor :math:`j` is in its coordination
117-
shell if :math:`k` is not blocking particle :math:`j`. In this implementation
118-
of RAD, we enforce symmetry, whereby neighboring particles must be in each
119-
others coordination shells.
120-
121-
Args:
122-
i_coords: xyz centre of mass of molecule :math:`i`
123-
sorted_distances: list of index and distance pairs sorted by distance
124-
coms: precomputed center of mass array
125-
dimensions: system box dimensions
126-
number_molecules: total number of molecules
127-
128-
Returns:
129-
shell: list of indices of particles in the RAD shell of neighbors.
130-
"""
131-
shell = []
132-
count = -1
133-
limit = min(number_molecules - 1, 30)
134-
135-
for y in range(limit):
136-
count += 1
137-
138-
j_idx = sorted_distances[y][0]
139-
r_ij = sorted_distances[y][1]
140-
j_coords = coms[j_idx]
141-
142-
blocked = False
143-
144-
for z in range(count):
145-
k_idx = sorted_distances[z][0]
146-
r_ik = sorted_distances[z][1]
147-
k_coords = coms[k_idx]
148-
149-
costheta_jik = self.get_angle(j_coords, i_coords, k_coords, dimensions)
150-
151-
if np.isnan(costheta_jik):
152-
break
153-
154-
LHS = (1 / r_ij) ** 2
155-
RHS = ((1 / r_ik) ** 2) * costheta_jik
156-
157-
if LHS < RHS:
158-
blocked = True
159-
break
160-
161-
if not blocked:
162-
shell.append(j_idx)
163-
164-
return shell
165-
166-
def get_angle(
167-
self, a: np.ndarray, b: np.ndarray, c: np.ndarray, dimensions: np.ndarray
168-
):
169-
"""
170-
Get the angle between three atoms, taking into account periodic
171-
boundary conditions.
172-
173-
b is the vertex of the angle.
174-
175-
Args:
176-
a: (3,) array of atom coordinates
177-
b: (3,) array of atom coordinates
178-
c: (3,) array of atom coordinates
179-
dimensions: (3,) array of system box dimensions.
180-
181-
Returns:
182-
cosine_angle: float, cosine of the angle abc.
183-
"""
184-
ba = np.abs(a - b)
185-
bc = np.abs(c - b)
186-
ac = np.abs(c - a)
187-
188-
ba = np.where(ba > 0.5 * dimensions, ba - dimensions, ba)
189-
bc = np.where(bc > 0.5 * dimensions, bc - dimensions, bc)
190-
ac = np.where(ac > 0.5 * dimensions, ac - dimensions, ac)
191-
192-
dist_ba = np.sqrt((ba**2).sum(axis=-1))
193-
dist_bc = np.sqrt((bc**2).sum(axis=-1))
194-
dist_ac = np.sqrt((ac**2).sum(axis=-1))
195-
196-
cosine_angle = (dist_ac**2 - dist_bc**2 - dist_ba**2) / (-2 * dist_bc * dist_ba)
197-
198-
return cosine_angle
199-
200198
def get_grid_neighbors(self, universe, mol_id, highest_level):
201199
"""
202-
Use MDAnalysis neighbor search to find neighbors.
203-
204-
For molecules with just one united atom, use the "A" search level to
205-
find neighboring atoms. For larger molecules use the "R" search level
206-
to find neighboring residues.
200+
Find neighbors using MDAnalysis grid-based neighbor search.
207201
208-
The atoms/residues of the molecule of interest are removed from the
209-
neighbor list.
202+
For small molecules (united_atom), atom-level search is used.
203+
For larger molecules, residue-level search is used.
210204
211205
Args:
212-
universe: MDAnalysis universe object for system.
213-
mol_id: int, the index for the molecule of interest.
214-
highest_level: str, molecule level.
206+
universe (MDAnalysis.Universe):
207+
MDAnalysis universe object for the system.
208+
mol_id (int):
209+
Index of the molecule of interest.
210+
highest_level (str):
211+
Molecule level ("united_atom" or other).
215212
216213
Returns:
217-
neighbors: MDAnalysis atomgroup of the neighboring particles.
214+
list[int]:
215+
Fragment indices of neighboring molecules.
218216
"""
219217
search_object = mda.lib.NeighborSearch.AtomNeighborSearch(universe.atoms)
220218
fragment = universe.atoms.fragments[mol_id]
219+
221220
selection_string = f"index {fragment.indices[0]}:{fragment.indices[-1]}"
222221
molecule_atom_group = universe.select_atoms(selection_string)
223222

conda-recipe/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ requirements:
3939
- distributed >=2026.1.2,<2026.2.0
4040
- dask-jobqueue >=0.9,<0.10
4141
- pytest-xdist >=3.8, <3.9
42+
- numba >=0.64.0, <0.70
4243

4344
test:
4445
imports:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ dependencies = [
5050
"waterEntropy>=2,<2.3",
5151
"requests>=2.32,<3.0",
5252
"rdkit>=2025.9.5",
53+
"numba>=0.65.0,<0.7",
5354
]
5455

5556
[project.urls]

0 commit comments

Comments
 (0)