@@ -18,10 +18,39 @@ def __init__(self):
1818 Initializes the Search class with a placeholder for the system
1919 trajectory.
2020 """
21-
2221 self ._universe = None
2322 self ._mol_id = None
2423
24+ def _get_fragment_coms (self , universe ):
25+ """
26+ Precompute fragment centres of mass.
27+
28+ Args:
29+ universe: MDAnalysis universe object.
30+
31+ Returns:
32+ np.ndarray: Array of fragment COMs.
33+ """
34+ return np .array ([frag .center_of_mass () for frag in universe .atoms .fragments ])
35+
36+ def _get_distances (self , coms , i_coords , dimensions ):
37+ """
38+ Function to calculate distances between a central point and all COMs.
39+ Takes periodic boundary conditions into account.
40+
41+ Args:
42+ coms: array of fragment COMs
43+ i_coords: coordinates of central molecule
44+ dimensions: simulation box dimensions
45+
46+ Returns:
47+ np.ndarray: distances to all molecules
48+ """
49+ delta = coms - i_coords
50+ delta = np .where (delta > 0.5 * dimensions , delta - dimensions , delta )
51+ delta = np .where (delta < - 0.5 * dimensions , delta + dimensions , delta )
52+ return np .sqrt ((delta ** 2 ).sum (axis = 1 ))
53+
2554 def get_RAD_neighbors (self , universe , mol_id ):
2655 """
2756 Find the neighbors of molecule with index mol_id.
@@ -36,28 +65,40 @@ def get_RAD_neighbors(self, universe, mol_id):
3665 """
3766 number_molecules = len (universe .atoms .fragments )
3867
39- central_position = universe .atoms .fragments [mol_id ].center_of_mass ()
68+ # Precompute COMs once
69+ coms = self ._get_fragment_coms (universe )
4070
41- # Find distances between molecule of interest and other molecules in the system
71+ # Central molecule position
72+ central_position = coms [mol_id ]
73+
74+ # Compute all distances in one vectorised call
75+ distances_array = self ._get_distances (
76+ coms , central_position , universe .dimensions [:3 ]
77+ )
78+
79+ # Build distance dict excluding self
4280 distances = {}
4381 for molecule_index_j in range (number_molecules ):
4482 if molecule_index_j != mol_id :
45- j_position = universe .atoms .fragments [molecule_index_j ].center_of_mass ()
46- distances [molecule_index_j ] = self .get_distance (
47- j_position , central_position , universe .dimensions [:3 ]
48- )
83+ distances [molecule_index_j ] = distances_array [molecule_index_j ]
4984
5085 # Sort distances smallest to largest
5186 sorted_dist = sorted (distances .items (), key = lambda item : item [1 ])
5287
5388 # Get indices of neighbors
5489 neighbor_indices = self ._get_RAD_indices (
55- central_position , sorted_dist , universe , number_molecules
90+ central_position ,
91+ sorted_dist ,
92+ coms ,
93+ universe .dimensions [:3 ],
94+ number_molecules ,
5695 )
5796
5897 return neighbor_indices
5998
60- def _get_RAD_indices (self , i_coords , sorted_distances , system , number_molecules ):
99+ def _get_RAD_indices (
100+ self , i_coords , sorted_distances , coms , dimensions , number_molecules
101+ ):
61102 # pylint: disable=too-many-locals
62103 r"""
63104 For a given set of atom coordinates, find its RAD shell from the distance
@@ -79,43 +120,45 @@ def _get_RAD_indices(self, i_coords, sorted_distances, system, number_molecules)
79120
80121 Args:
81122 i_coords: xyz centre of mass of molecule :math:`i`
82- sorted_indices: dict of index and distance pairs sorted by distance
83- system: mdanalysis instance of atoms in a frame
123+ sorted_distances: list of index and distance pairs sorted by distance
124+ coms: precomputed center of mass array
125+ dimensions: system box dimensions
126+ number_molecules: total number of molecules
84127
85128 Returns:
86129 shell: list of indices of particles in the RAD shell of neighbors.
87130 """
88- # 1. truncate neighbor list to closest 30 united atoms and iterate
89- # through neighbors from closest to furthest/
90131 shell = []
91132 count = - 1
92133 limit = min (number_molecules - 1 , 30 )
134+
93135 for y in range (limit ):
94136 count += 1
137+
95138 j_idx = sorted_distances [y ][0 ]
96- j_coords = system .atoms .fragments [j_idx ].center_of_mass ()
97139 r_ij = sorted_distances [y ][1 ]
140+ j_coords = coms [j_idx ]
141+
98142 blocked = False
99- # 3. iterate through neighbors other than atom j and check if they block
100- # it from molecule i
101- for z in range (count ): # only closer units can block
143+
144+ for z in range (count ):
102145 k_idx = sorted_distances [z ][0 ]
103- k_coords = system .atoms .fragments [k_idx ].center_of_mass ()
104146 r_ik = sorted_distances [z ][1 ]
105- # 4. find the angle jik
106- costheta_jik = self . get_angle (
107- j_coords , i_coords , k_coords , system . dimensions [: 3 ]
108- )
147+ k_coords = coms [ k_idx ]
148+
149+ costheta_jik = self . get_angle ( j_coords , i_coords , k_coords , dimensions )
150+
109151 if np .isnan (costheta_jik ):
110152 break
111- # 5. check if k blocks j from i
153+
112154 LHS = (1 / r_ij ) ** 2
113155 RHS = ((1 / r_ik ) ** 2 ) * costheta_jik
156+
114157 if LHS < RHS :
115158 blocked = True
116159 break
117- # 6. if j is not blocked from i by k, then its in i's shell
118- if blocked is False :
160+
161+ if not blocked :
119162 shell .append (j_idx )
120163
121164 return shell
@@ -125,67 +168,35 @@ def get_angle(
125168 ):
126169 """
127170 Get the angle between three atoms, taking into account periodic
128- bondary conditions.
171+ boundary conditions.
129172
130173 b is the vertex of the angle.
131174
132- Pairwise differences between the coordinates are used with the
133- distances calculated as the square root of the sum of the squared
134- x, y, and z coordinates.
135-
136175 Args:
137- a: (3,) array of atom cooordinates
138- b: (3,) array of atom cooordinates
139- c: (3,) array of atom cooordinates
176+ a: (3,) array of atom coordinates
177+ b: (3,) array of atom coordinates
178+ c: (3,) array of atom coordinates
140179 dimensions: (3,) array of system box dimensions.
141180
142181 Returns:
143182 cosine_angle: float, cosine of the angle abc.
144183 """
145- # Differences in positions
146184 ba = np .abs (a - b )
147185 bc = np .abs (c - b )
148186 ac = np .abs (c - a )
149187
150- # Correct for periodic boundary conditions
151188 ba = np .where (ba > 0.5 * dimensions , ba - dimensions , ba )
152189 bc = np .where (bc > 0.5 * dimensions , bc - dimensions , bc )
153190 ac = np .where (ac > 0.5 * dimensions , ac - dimensions , ac )
154191
155- # Get distances
156192 dist_ba = np .sqrt ((ba ** 2 ).sum (axis = - 1 ))
157193 dist_bc = np .sqrt ((bc ** 2 ).sum (axis = - 1 ))
158194 dist_ac = np .sqrt ((ac ** 2 ).sum (axis = - 1 ))
159195
160- # Trigonometry
161196 cosine_angle = (dist_ac ** 2 - dist_bc ** 2 - dist_ba ** 2 ) / (- 2 * dist_bc * dist_ba )
162197
163198 return cosine_angle
164199
165- def get_distance (self , j_position , i_position , dimensions ):
166- """
167- Function to calculate the distance between two points.
168- Take periodic boundary conditions into account.
169-
170- Args:
171- j_position: the x, y, z coordinates of point 1
172- i_position: the x, y, z coordinates of the other point
173- dimensions: the dimensions of the simulation box
174-
175- Returns:
176- distance: float, the distance between the two points
177- """
178- # Difference in positions
179- delta = np .abs (j_position - i_position )
180-
181- # Account for periodic boundary conditions
182- delta = np .where (delta > 0.5 * dimensions , delta - dimensions , delta )
183-
184- # Get distance value
185- distance = np .sqrt ((delta ** 2 ).sum (axis = - 1 ))
186-
187- return distance
188-
189200 def get_grid_neighbors (self , universe , mol_id , highest_level ):
190201 """
191202 Use MDAnalysis neighbor search to find neighbors.
@@ -211,30 +222,20 @@ def get_grid_neighbors(self, universe, mol_id, highest_level):
211222 molecule_atom_group = universe .select_atoms (selection_string )
212223
213224 if highest_level == "united_atom" :
214- # For united atom size molecules, use the grid search
215- # to find neighboring atoms
216- search_level = "A"
217225 search = mda .lib .NeighborSearch .AtomNeighborSearch .search (
218226 search_object ,
219227 molecule_atom_group ,
220228 radius = 3.0 ,
221- level = search_level ,
229+ level = "A" ,
222230 )
223- # Make sure that the neighbors list does not include
224- # atoms from the central molecule
225- # neighbors = search - fragment.residues
226231 neighbors = search - molecule_atom_group
227232 else :
228- # For larger molecules, use the grid search to find neighboring residues
229- search_level = "R"
230233 search = mda .lib .NeighborSearch .AtomNeighborSearch .search (
231234 search_object ,
232235 molecule_atom_group ,
233236 radius = 3.5 ,
234- level = search_level ,
237+ level = "R" ,
235238 )
236- # Make sure that the neighbors list does not include
237- # residues from the central molecule
238239 neighbors = search - fragment .residues
239240 neighbors = neighbors .atoms
240241
0 commit comments