Skip to content

Commit 9bebe97

Browse files
committed
perf: optimise RAD neighbour search by precomputing COMs and vectorising distances
1 parent a7c1d9a commit 9bebe97

1 file changed

Lines changed: 74 additions & 73 deletions

File tree

CodeEntropy/levels/search.py

Lines changed: 74 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,39 @@ def __init__(self):
1818
Initializes the Search class with a placeholder for the system
1919
trajectory.
2020
"""
21-
2221
self._universe = None
2322
self._mol_id = None
2423

24+
def _get_fragment_coms(self, universe):
    """
    Precompute the centre of mass of every fragment in the universe.

    Args:
        universe: MDAnalysis universe object. Only
            ``universe.atoms.fragments`` is read and ``center_of_mass()``
            is called once on each fragment.

    Returns:
        np.ndarray: ``(n_fragments, 3)`` float array of fragment COMs, in
        the same order as ``universe.atoms.fragments``. A system without
        fragments yields an empty array of shape ``(0, 3)``.
    """
    coms = [frag.center_of_mass() for frag in universe.atoms.fragments]
    # Force a 2-D (n, 3) result so row indexing and axis=1 reductions still
    # work when there are no fragments (np.array([]) would be 1-D).
    # NOTE(review): assumes center_of_mass() returns one xyz triple per
    # fragment -- confirm if compound COMs are ever requested.
    return np.asarray(coms, dtype=float).reshape(-1, 3)
35+
36+
def _get_distances(self, coms, i_coords, dimensions):
    """
    Calculate distances between a central point and all COMs.

    Periodic boundary conditions are applied per axis with the
    minimum-image convention for an orthorhombic box: each displacement
    component is shifted by the nearest whole number of box lengths, so
    the result is also correct when coordinates are several box lengths
    apart (e.g. unwrapped coordinates).

    Args:
        coms: array-like of fragment COMs, shape ``(n, 3)``
        i_coords: ``(3,)`` coordinates of the central molecule
        dimensions: ``(3,)`` simulation box edge lengths. Axes with a
            non-positive length are treated as non-periodic.

    Returns:
        np.ndarray: ``(n,)`` array of distances to all molecules
    """
    box = np.asarray(dimensions, dtype=float)
    # reshape keeps the array 2-D so an empty input gives an empty result
    # instead of an axis error in the final reduction.
    delta = np.asarray(coms, dtype=float).reshape(-1, 3) - np.asarray(
        i_coords, dtype=float
    )

    # Number of whole box lengths to remove along each axis; left at zero
    # where the box length is not positive to avoid dividing by zero.
    n_boxes = np.zeros_like(delta)
    np.divide(delta, box, out=n_boxes, where=box > 0)
    delta = delta - box * np.round(n_boxes)

    return np.sqrt((delta**2).sum(axis=1))
53+
2554
def get_RAD_neighbors(self, universe, mol_id):
2655
"""
2756
Find the neighbors of molecule with index mol_id.
@@ -36,28 +65,40 @@ def get_RAD_neighbors(self, universe, mol_id):
3665
"""
3766
number_molecules = len(universe.atoms.fragments)
3867

39-
central_position = universe.atoms.fragments[mol_id].center_of_mass()
68+
# Precompute COMs once
69+
coms = self._get_fragment_coms(universe)
4070

41-
# Find distances between molecule of interest and other molecules in the system
71+
# Central molecule position
72+
central_position = coms[mol_id]
73+
74+
# Compute all distances in one vectorised call
75+
distances_array = self._get_distances(
76+
coms, central_position, universe.dimensions[:3]
77+
)
78+
79+
# Build distance dict excluding self
4280
distances = {}
4381
for molecule_index_j in range(number_molecules):
4482
if molecule_index_j != mol_id:
45-
j_position = universe.atoms.fragments[molecule_index_j].center_of_mass()
46-
distances[molecule_index_j] = self.get_distance(
47-
j_position, central_position, universe.dimensions[:3]
48-
)
83+
distances[molecule_index_j] = distances_array[molecule_index_j]
4984

5085
# Sort distances smallest to largest
5186
sorted_dist = sorted(distances.items(), key=lambda item: item[1])
5287

5388
# Get indices of neighbors
5489
neighbor_indices = self._get_RAD_indices(
55-
central_position, sorted_dist, universe, number_molecules
90+
central_position,
91+
sorted_dist,
92+
coms,
93+
universe.dimensions[:3],
94+
number_molecules,
5695
)
5796

5897
return neighbor_indices
5998

60-
def _get_RAD_indices(self, i_coords, sorted_distances, system, number_molecules):
99+
def _get_RAD_indices(
100+
self, i_coords, sorted_distances, coms, dimensions, number_molecules
101+
):
61102
# pylint: disable=too-many-locals
62103
r"""
63104
For a given set of atom coordinates, find its RAD shell from the distance
@@ -79,43 +120,45 @@ def _get_RAD_indices(self, i_coords, sorted_distances, system, number_molecules)
79120
80121
Args:
81122
i_coords: xyz centre of mass of molecule :math:`i`
82-
sorted_indices: dict of index and distance pairs sorted by distance
83-
system: mdanalysis instance of atoms in a frame
123+
sorted_distances: list of index and distance pairs sorted by distance
124+
coms: precomputed center of mass array
125+
dimensions: system box dimensions
126+
number_molecules: total number of molecules
84127
85128
Returns:
86129
shell: list of indices of particles in the RAD shell of neighbors.
87130
"""
88-
# 1. truncate neighbor list to closest 30 united atoms and iterate
89-
# through neighbors from closest to furthest/
90131
shell = []
91132
count = -1
92133
limit = min(number_molecules - 1, 30)
134+
93135
for y in range(limit):
94136
count += 1
137+
95138
j_idx = sorted_distances[y][0]
96-
j_coords = system.atoms.fragments[j_idx].center_of_mass()
97139
r_ij = sorted_distances[y][1]
140+
j_coords = coms[j_idx]
141+
98142
blocked = False
99-
# 3. iterate through neighbors other than atom j and check if they block
100-
# it from molecule i
101-
for z in range(count): # only closer units can block
143+
144+
for z in range(count):
102145
k_idx = sorted_distances[z][0]
103-
k_coords = system.atoms.fragments[k_idx].center_of_mass()
104146
r_ik = sorted_distances[z][1]
105-
# 4. find the angle jik
106-
costheta_jik = self.get_angle(
107-
j_coords, i_coords, k_coords, system.dimensions[:3]
108-
)
147+
k_coords = coms[k_idx]
148+
149+
costheta_jik = self.get_angle(j_coords, i_coords, k_coords, dimensions)
150+
109151
if np.isnan(costheta_jik):
110152
break
111-
# 5. check if k blocks j from i
153+
112154
LHS = (1 / r_ij) ** 2
113155
RHS = ((1 / r_ik) ** 2) * costheta_jik
156+
114157
if LHS < RHS:
115158
blocked = True
116159
break
117-
# 6. if j is not blocked from i by k, then its in i's shell
118-
if blocked is False:
160+
161+
if not blocked:
119162
shell.append(j_idx)
120163

121164
return shell
@@ -125,67 +168,35 @@ def get_angle(
125168
):
126169
"""
127170
Get the angle between three atoms, taking into account periodic
128-
bondary conditions.
171+
boundary conditions.
129172
130173
b is the vertex of the angle.
131174
132-
Pairwise differences between the coordinates are used with the
133-
distances calculated as the square root of the sum of the squared
134-
x, y, and z coordinates.
135-
136175
Args:
137-
a: (3,) array of atom cooordinates
138-
b: (3,) array of atom cooordinates
139-
c: (3,) array of atom cooordinates
176+
a: (3,) array of atom coordinates
177+
b: (3,) array of atom coordinates
178+
c: (3,) array of atom coordinates
140179
dimensions: (3,) array of system box dimensions.
141180
142181
Returns:
143182
cosine_angle: float, cosine of the angle abc.
144183
"""
145-
# Differences in positions
146184
ba = np.abs(a - b)
147185
bc = np.abs(c - b)
148186
ac = np.abs(c - a)
149187

150-
# Correct for periodic boundary conditions
151188
ba = np.where(ba > 0.5 * dimensions, ba - dimensions, ba)
152189
bc = np.where(bc > 0.5 * dimensions, bc - dimensions, bc)
153190
ac = np.where(ac > 0.5 * dimensions, ac - dimensions, ac)
154191

155-
# Get distances
156192
dist_ba = np.sqrt((ba**2).sum(axis=-1))
157193
dist_bc = np.sqrt((bc**2).sum(axis=-1))
158194
dist_ac = np.sqrt((ac**2).sum(axis=-1))
159195

160-
# Trigonometry
161196
cosine_angle = (dist_ac**2 - dist_bc**2 - dist_ba**2) / (-2 * dist_bc * dist_ba)
162197

163198
return cosine_angle
164199

165-
def get_distance(self, j_position, i_position, dimensions):
166-
"""
167-
Function to calculate the distance between two points.
168-
Take periodic boundary conditions into account.
169-
170-
Args:
171-
j_position: the x, y, z coordinates of point 1
172-
i_position: the x, y, z coordinates of the other point
173-
dimensions: the dimensions of the simulation box
174-
175-
Returns:
176-
distance: float, the distance between the two points
177-
"""
178-
# Difference in positions
179-
delta = np.abs(j_position - i_position)
180-
181-
# Account for periodic boundary conditions
182-
delta = np.where(delta > 0.5 * dimensions, delta - dimensions, delta)
183-
184-
# Get distance value
185-
distance = np.sqrt((delta**2).sum(axis=-1))
186-
187-
return distance
188-
189200
def get_grid_neighbors(self, universe, mol_id, highest_level):
190201
"""
191202
Use MDAnalysis neighbor search to find neighbors.
@@ -211,30 +222,20 @@ def get_grid_neighbors(self, universe, mol_id, highest_level):
211222
molecule_atom_group = universe.select_atoms(selection_string)
212223

213224
if highest_level == "united_atom":
214-
# For united atom size molecules, use the grid search
215-
# to find neighboring atoms
216-
search_level = "A"
217225
search = mda.lib.NeighborSearch.AtomNeighborSearch.search(
218226
search_object,
219227
molecule_atom_group,
220228
radius=3.0,
221-
level=search_level,
229+
level="A",
222230
)
223-
# Make sure that the neighbors list does not include
224-
# atoms from the central molecule
225-
# neighbors = search - fragment.residues
226231
neighbors = search - molecule_atom_group
227232
else:
228-
# For larger molecules, use the grid search to find neighboring residues
229-
search_level = "R"
230233
search = mda.lib.NeighborSearch.AtomNeighborSearch.search(
231234
search_object,
232235
molecule_atom_group,
233236
radius=3.5,
234-
level=search_level,
237+
level="R",
235238
)
236-
# Make sure that the neighbors list does not include
237-
# residues from the central molecule
238239
neighbors = search - fragment.residues
239240
neighbors = neighbors.atoms
240241

0 commit comments

Comments
 (0)