Skip to content

Commit b899373

Browse files
authored
Merge pull request #75 from avmarchenko/master
Adds in basic support for nearest neighbor searches
2 parents b853688 + d5b010c commit b899373

4 files changed

Lines changed: 139 additions & 40 deletions

File tree

exatomic/algorithms/distance.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,46 @@ def periodic_pdist_euc_dxyz_idx(ux, uy, uz, rx, ry, rz, idxs, tol=10**-8):
101101
return dx, dy, dz, dr, idxi, idxj, px, py, pz
102102

103103

104+
def _compute(cx, cy, cz, rx, ry, rz, ox, oy, oz):
105+
"""
106+
"""
107+
l = [-1, 0, 1]
108+
m = len(cx)
109+
dx = np.empty((m, ), dtype=np.float64)
110+
dy = np.empty((m, ), dtype=np.float64)
111+
dz = np.empty((m, ), dtype=np.float64)
112+
px = np.empty((27, ), dtype=np.float64)
113+
py = np.empty((27, ), dtype=np.float64)
114+
pz = np.empty((27, ), dtype=np.float64)
115+
pr = np.empty((27, ), dtype=np.float64)
116+
h = 0
117+
for i in range(m):
118+
cxi = cx[i]
119+
cyi = cy[i]
120+
czi = cz[i]
121+
hh = 0
122+
for ii in l:
123+
for jj in l:
124+
for kk in l:
125+
sx = ii*rx
126+
sy = jj*ry
127+
sz = kk*rz
128+
xx = cxi + sx
129+
yy = cyi + sy
130+
zz = czi + sz
131+
pr[hh] = (ox - xx)**2 + (oy - yy)**2 + (oz - zz)**2
132+
px[hh] = sx
133+
py[hh] = sy
134+
pz[hh] = sz
135+
hh += 1
136+
hh = np.argmin(pr)
137+
dx[h] = px[hh]
138+
dy[h] = py[hh]
139+
dz[h] = pz[hh]
140+
h += 1
141+
return dx, dy, dz
142+
143+
104144
if config['dynamic']['numba'] == 'true':
105145
from numba import jit, vectorize
106146
from exa.math.vector.cartesian import magnitude_xyz
@@ -109,3 +149,4 @@ def periodic_pdist_euc_dxyz_idx(ux, uy, uz, rx, ry, rz, idxs, tol=10**-8):
109149
minimal_image_counts = jit(nopython=True, cache=True, nogil=True)(minimal_image_counts)
110150
minimal_image = vectorize(types3, nopython=True)(minimal_image)
111151
periodic_pdist_euc_dxyz_idx = jit(nopython=True, cache=True, nogil=True)(periodic_pdist_euc_dxyz_idx)
152+
_compute = jit(nopython=True, cache=True, nogil=True)(_compute)

exatomic/algorithms/neighbors.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
and :func:`~exatomic.molecule.Molecule.classify`.
1515
"""
1616
import numpy as np
17-
from exatomic.container import Universe
17+
import pandas as pd
18+
from exatomic.algorithms.distance import _compute
1819

1920

2021
def nearest_molecules(universe, n, sources, restrictions=None, how='atom',
@@ -46,13 +47,17 @@ def nearest_molecules(universe, n, sources, restrictions=None, how='atom',
4647
Returns:
4748
unis (dict): Dictionary of number of neighbors keys, universe values
4849
"""
49-
source_atoms, other_atoms, source_molecules, other_molecules, n = _slice_atoms_molecules(universe, sources, restrictions, how)
50+
source_atoms, other_atoms, source_molecules, other_molecules, n = _slice_atoms_molecules(universe, sources, restrictions, n)
5051
ordered_molecules, ordered_twos = _compute_neighbors_by_atom(universe, source_atoms, other_atoms, source_molecules)
51-
unis = {nn: _build_universe(universe, ordered_molecules, ordered_twos, nn) for nn in n}
52-
53-
54-
55-
52+
unis = {}
53+
if free_boundary == True:
54+
for nn in n:
55+
unis[nn] = _build_free_universe(universe, ordered_molecules,
56+
ordered_twos, nn, source_atoms,
57+
source_molecules)
58+
else:
59+
raise NotImplementedError()
60+
return unis
5661

5762

5863
def _slice_atoms_molecules(universe, sources, restrictions, n):
@@ -66,7 +71,7 @@ def _slice_atoms_molecules(universe, sources, restrictions, n):
6671
sources = [sources]
6772
if not isinstance(restrictions, list) and restrictions is not None:
6873
restrictions = [restrictions]
69-
if not isinstance(n, list):
74+
if isinstance(n, (int, np.int32, np.int64)):
7075
n = [n]
7176
symbols = universe.atom['symbol'].unique()
7277
classification = universe.molecule['classification'].unique()
@@ -135,12 +140,67 @@ def _compute_neighbors_by_com(universe, source_molecules, other_molecules):
135140
raise NotImplementedError()
136141

137142

143+
def _build_free_universe(universe, ordered_molecules, ordered_twos, n,
144+
source_atoms, source_molecules):
145+
"""
146+
"""
147+
molecule = np.concatenate([mcules[:n] for mcules in ordered_molecules])
148+
molecule = np.concatenate((molecule, source_molecules.index.tolist()))
149+
molecule = universe.molecule[universe.molecule.index.isin(molecule)].copy()
150+
atom = universe.atom[universe.atom['molecule'].isin(molecule.index)].copy()
151+
atom_two = universe.atom_two[(universe.atom_two['atom0'].isin(atom.index) &
152+
universe.atom_two['atom1'].isin(atom.index))].copy()
153+
frame = universe.frame[universe.frame.index.isin(atom['frame'])].copy()
154+
frame['periodic'] = False
155+
uni = universe.__class__(atom=atom, molecule=molecule, frame=frame, atom_two=atom_two)
156+
if universe.frame.is_periodic():
157+
uni.atom.update(universe.visual_atom)
158+
uni.compute_molecule_com()
159+
uni.atom._revert_categories()
160+
mapper = uni.atom.drop_duplicates('molecule').set_index('molecule')['frame']
161+
uni.atom._set_categories()
162+
uni.molecule['frame'] = uni.molecule.index.map(lambda x: mapper[x])
163+
sources = source_atoms.groupby('frame')
164+
groups = uni.molecule.groupby('frame')
165+
n = groups.ngroups
166+
dx = np.empty((n, ), dtype=np.ndarray)
167+
dy = np.empty((n, ), dtype=np.ndarray)
168+
dz = np.empty((n, ), dtype=np.ndarray)
169+
index = np.empty((n, ), dtype=np.ndarray)
170+
for i, (frame, group) in enumerate(groups):
171+
cx = group['cx'].values
172+
cy = group['cy'].values
173+
cz = group['cz'].values
174+
ccx, ccy, ccz = sources.get_group(frame)[['x', 'y', 'z']].mean().values
175+
# ccx, ccy, ccz = mcules.ix[mcules['classification'] == 'solute', ['cx', 'cy', 'cz']].values[0]
176+
rx, ry, rz = uni.frame.ix[frame, ['rx', 'ry', 'rz']].values
177+
dxf, dyf, dzf = _compute(cx, cy, cz, rx, ry, rz, ccx, ccy, ccz)
178+
dx[i] = dxf
179+
dy[i] = dyf
180+
dz[i] = dzf
181+
index[i] = group.index.values
182+
del uni.molecule['frame']
183+
dx = np.concatenate(dx)
184+
dy = np.concatenate(dy)
185+
dz = np.concatenate(dz)
186+
index = np.concatenate(index)
187+
df = pd.DataFrame.from_dict({'x': dx, 'y': dy, 'z': dz, 'molecule': index})
188+
df.set_index('molecule', inplace=True)
189+
for molecule in df.index:
190+
dx, dy, dz = df.ix[molecule].values
191+
uni.atom.ix[uni.atom['molecule'] == molecule, 'x'] += dx
192+
uni.atom.ix[uni.atom['molecule'] == molecule, 'y'] += dy
193+
uni.atom.ix[uni.atom['molecule'] == molecule, 'z'] += dz
194+
return uni
195+
196+
138197
def _build_universe(universe, ordered_molecules, ordered_twos, n):
139198
"""
140199
"""
200+
raise NotImplementedError()
141201
# TODO CONVERT TO A GENERIC AND COMPLETE SLICER
142-
molecules = np.concatenate([m[:n] for m in ordered_molecules.values])
143-
twos = np.concatenate([t[:n] for t in ordered_twos.values])
202+
molecules = np.concatenate([m[:n] for m in ordered_molecules])
203+
twos = np.concatenate([t[:n] for t in ordered_twos])
144204
atom = universe.atom[universe.atom['molecule'].isin(molecules)].copy().sort_index()
145205
two = universe.atom_two[universe.atom_two['atom0'].isin(atom.index) &
146206
universe.atom_two['atom1'].isin(atom.index)].copy().sort_index()

exatomic/container.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
from exatomic.atom import Atom, UnitAtom, ProjectedAtom, VisualAtom
2222
from exatomic.two import (AtomTwo, MoleculeTwo, compute_atom_two,
2323
compute_bond_count, compute_molecule_two)
24-
from exatomic.molecule import Molecule, compute_molecule, compute_com
24+
from exatomic.molecule import (Molecule, compute_molecule, compute_molecule_com,
25+
compute_molecule_count)
2526
from exatomic.widget import UniverseWidget
2627
from exatomic.field import AtomicField
2728
from exatomic.orbital import Orbital, MOMatrix, DensityMatrix
2829
from exatomic.basis import (SphericalGTFOrder, CartesianGTFOrder, Overlap,
29-
BasisSetSummary, GaussianBasisSet, BasisSetOrder,
30-
Primitive)
30+
BasisSetSummary, GaussianBasisSet, BasisSetOrder)
3131

3232

3333
class Meta(TypedMeta):
@@ -49,7 +49,6 @@ class Meta(TypedMeta):
4949
orbital = Orbital
5050
overlap = Overlap
5151
momatrix = MOMatrix
52-
primitive = Primitive
5352
density = DensityMatrix
5453
basis_set_order = BasisSetOrder
5554
basis_set_summary = BasisSetSummary
@@ -133,15 +132,19 @@ def compute_molecule(self):
133132
self.molecule = compute_molecule(self)
134133

135134
def compute_molecule_com(self):
136-
cx, cy, cz = compute_com(self)
135+
cx, cy, cz = compute_molecule_com(self)
137136
self.molecule['cx'] = cx
138137
self.molecule['cy'] = cy
139138
self.molecule['cz'] = cz
140139

141140
def compute_atom_count(self):
142-
"""Compute the atom count for each frame."""
141+
"""Compute number of atoms per frame."""
143142
self.frame['atom_count'] = self.atom.grouped().size()
144143

144+
def compute_molecule_count(self):
    """
    Store the number of molecules per frame in the frame table.

    The counts come from the module-level ``compute_molecule_count``
    function and are written to the ``molecule_count`` column of
    ``self.frame``.
    """
    counts = compute_molecule_count(self)
    self.frame['molecule_count'] = counts
147+
145148
def _custom_traits(self):
146149
"""
147150
Build traits depending on multiple dataframes.
@@ -153,20 +156,6 @@ def _custom_traits(self):
153156
traits.update(self.atom_two._bond_traits(mapper))
154157
return traits
155158

156-
@classmethod
157-
def from_small_molecule_data(cls, center=None, ligand=None, distance=None, geometry=None,
158-
offset=None, plane=None, axis=None, domains=None, unit='A'):
159-
'''
160-
Build a universe from small molecule data
161-
162-
See
163-
exatomic.algorithms.geometry.make_small_molecule
164-
'''
165-
return cls(atom=Atom.from_small_molecule_data(center=center, ligand=ligand,
166-
distance=distance, geometry=geometry,
167-
offset=offset, plane=plane, axis=axis,
168-
domains=domains, unit=unit))
169-
170159
def __len__(self):
171160
return len(self.frame)
172161

exatomic/molecule.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,28 +140,37 @@ def compute_molecule(universe):
140140
def compute_molecule_count(universe):
    """
    Count the molecules belonging to each frame.

    Every molecule is assigned the frame of the first atom row carrying its
    label, then the molecule table is grouped by that frame. If the atom
    table has no ``molecule`` column yet, ``universe.compute_molecule()`` is
    called first. A ``frame`` column is added to ``universe.molecule`` while
    counting and deleted again before returning.

    Args:
        universe: Object exposing ``atom`` and ``molecule`` tables and a ``compute_molecule`` method

    Returns:
        molecule_count: Result of ``groupby('frame').size()`` on the molecule table
    """
    labelled = 'molecule' in universe.atom.columns
    if not labelled:
        universe.compute_molecule()
    # NOTE(review): the category revert/restore pair presumably lets the
    # lookup below work on plain (non-categorical) labels - confirm.
    universe.atom._revert_categories()
    first_rows = universe.atom.drop_duplicates('molecule')
    frame_of = first_rows.set_index('molecule')['frame']
    universe.atom._set_categories()
    # Temporary column used only for the grouping; removed below.
    universe.molecule['frame'] = universe.molecule.index.map(lambda label: frame_of[label])
    per_frame = universe.molecule.groupby('frame').size()
    del universe.molecule['frame']
    return per_frame
148152

149153

150-
def compute_com(universe):
154+
def compute_molecule_com(universe):
151155
"""
152156
Compute molecules' centers of mass.
153157
"""
154158
if 'molecule' not in universe.atom.columns:
155159
universe.compute_molecule()
156-
xyz = universe.atom[['x', 'y', 'z', 'molecule']].copy()
157-
xyz['mass'] = universe.atom.get_element_masses()
158-
xyz.update(u.visual_atom)
159-
xyz['xm'] = xyz['x'].mul(xyz['mass'])
160-
xyz['ym'] = xyz['y'].mul(xyz['mass'])
161-
xyz['zm'] = xyz['z'].mul(xyz['mass'])
162-
xyz['rm'] = xyz['xm'].add(xyz['ym']).add(xyz['zm'])
163-
grps = xyz.groupby('molecule')
164-
sums = grps.sum()
160+
mass = universe.atom.get_element_masses()
161+
if universe.frame.is_periodic():
162+
xyz = universe.atom[['x', 'y', 'z']].copy()
163+
xyz.update(u.visual_atom)
164+
else:
165+
xyz = universe.atom[['x', 'y', 'z']]
166+
xm = xyz['x'].mul(mass)
167+
ym = xyz['y'].mul(mass)
168+
zm = xyz['z'].mul(mass)
169+
rm = xm.add(ym).add(zm)
170+
df = pd.DataFrame.from_dict({'xm': xm, 'ym': ym, 'zm': zm, 'mass': mass,
171+
'molecule': universe.atom['molecule']})
172+
groups = df.groupby('molecule')
173+
sums = groups.sum()
165174
cx = sums['xm'].div(sums['mass'])
166175
cy = sums['ym'].div(sums['mass'])
167176
cz = sums['zm'].div(sums['mass'])

0 commit comments

Comments
 (0)