@@ -48,7 +48,12 @@ def nearest_molecules(universe, n, sources, restrictions=None, how='atom',
4848 unis (dict): Dictionary of number of neighbors keys, universe values
4949 """
5050 source_atoms , other_atoms , source_molecules , other_molecules , n = _slice_atoms_molecules (universe , sources , restrictions , n )
51+ print (source_atoms .shape )
52+ print (other_atoms .shape )
53+ print (source_molecules .shape )
54+ print (other_molecules .shape )
5155 ordered_molecules , ordered_twos = _compute_neighbors_by_atom (universe , source_atoms , other_atoms , source_molecules )
56+ print (ordered_molecules .shape )
5257 unis = {}
5358 if free_boundary == True :
5459 for nn in n :
@@ -65,33 +70,50 @@ def _slice_atoms_molecules(universe, sources, restrictions, n):
6570 Initial check of the unvierse data and argument types and creation of atom
6671 and molecule table slices.
6772 """
68- if 'classification' not in universe .molecule .columns and any (len (source ) > 3 for source in sources ):
69- raise KeyErrror ("Column 'classification' not in the molecule table, please classify molecules or select by symbols only." )
7073 if not isinstance (sources , list ):
7174 sources = [sources ]
7275 if not isinstance (restrictions , list ) and restrictions is not None :
7376 restrictions = [restrictions ]
7477 if isinstance (n , (int , np .int32 , np .int64 )):
7578 n = [n ]
79+ labels = universe .atom .get_atom_labels ()
80+ universe .atom ['label' ] = labels
81+ labels = labels .unique ()
7682 symbols = universe .atom ['symbol' ].unique ()
7783 classification = universe .molecule ['classification' ].unique ()
78- if all (source in symbols for source in sources ):
84+ if all (source in labels for source in sources ):
85+ print ("all labels" )
86+ source_atoms = universe .atom [universe .atom ['label' ].isin (sources )]
87+ mdx = source_atoms ['molecule' ].astype (np .int64 )
88+ source_molecules = universe .molecule [universe .molecule .index .isin (mdx )]
89+ elif all (source in symbols for source in sources ):
90+ print ("all symbols" )
7991 source_atoms = universe .atom [universe .atom ['symbol' ].isin (sources )]
8092 mdx = source_atoms ['molecule' ].astype (np .int64 )
8193 source_molecules = universe .molecule [universe .molecule .index .isin (mdx )]
8294 elif all (source in classification for source in sources ):
95+ print ("all mols" )
8396 source_molecules = universe .molecule [universe .molecule ['classification' ].isin (sources )]
8497 source_atoms = universe .atom [universe .atom ['molecule' ].isin (source_molecules .index )]
8598 else :
99+ print ("all other" )
86100 classif = [source for source in sources if source in classification ]
87101 syms = [source for source in sources if source in symbols ]
102+ lbls = [source for source in sources if source in labels ]
88103 source_molecules = universe .molecule [universe .molecule ['classification' ].isin (classif )]
89104 source_atoms = universe .atom [universe .atom ['molecule' ].isin (source_molecules .index )]
90- source_atoms = source_atoms [source_atoms ['symbol' ].isin (syms )]
105+ if len (syms ) > 0 :
106+ source_atoms = source_atoms [source_atoms ['symbol' ].isin (syms )]
107+ if len (lbls ) > 0 :
108+ source_atoms = source_atoms [source_atoms ['label' ].isin (lbls )]
91109 other_molecules = universe .molecule [~ universe .molecule .index .isin (source_molecules .index )]
92110 other_atoms = universe .atom [~ universe .atom .index .isin (source_atoms .index )]
93111 if restrictions is not None :
94- if all (other in symbols for other in restrictions ):
112+ if all (other in labels for other in restrictions ):
113+ other_atoms = other_atoms [other_atoms ['label' ].isin (restrictions )]
114+ mdx = other_atoms ['molecule' ].astype (np .int64 )
115+ other_molecules = other_molecules [other_molecules .index .isin (mdx )]
116+ elif all (other in symbols for other in restrictions ):
95117 other_atoms = other_atoms [other_atoms ['symbol' ].isin (restrictions )]
96118 mdx = other_atoms ['molecule' ].astype (np .int64 )
97119 other_molecules = other_molecules [other_molecules .index .isin (mdx )]
@@ -101,9 +123,14 @@ def _slice_atoms_molecules(universe, sources, restrictions, n):
101123 else :
102124 classif = [other for other in restrictions if other in classification ]
103125 syms = [other for other in restrictions if other in symbols ]
126+ lbls = [other for other in restrictions if other in labels ]
104127 other_molecules = other_molecules [other_molecules ['classification' ].isin (classif )]
105128 other_atoms = other_atoms [other_atoms ['molecule' ].isin (other_molecules .index )]
106- other_atoms = other_atoms [other_atoms ['symbol' ].isin (syms )]
129+ if len (syms ) > 0 :
130+ other_atoms = other_atoms [other_atoms ['symbol' ].isin (syms )]
131+ if len (lbls ) > 0 :
132+ other_atoms = other_atoms [other_atoms ['label' ].isin (lbls )]
133+ del universe .atom ['label' ]
107134 return source_atoms , other_atoms , source_molecules , other_molecules , n
108135
109136
0 commit comments