@@ -159,7 +159,10 @@ def extra_repr(self):
159159
class MolDataset(torch.utils.data.Dataset):
    '''A pytorch mappable dataset for molgrid training files.'''

    def __init__(self, *args,
                 random_translation: float = 0.0,
                 random_rotation: bool = False,
                 **kwargs):
        '''Initialize mappable MolGridDataset.

        :param input(s): File name(s) of training example files
        :param typers: A tuple of AtomTypers to use
        :param random_translation: maximum random translation (in each axis)
            applied to an example's coordinates when it is accessed
        :param random_rotation: if True, apply a uniform random rotation to an
            example's coordinates when it is accessed
        :param ligmolcache: precalculated molcache2 file for ligand; if doesn't exist, will look in data_root

        Remaining keyword arguments are forwarded to ``mg.ExampleDataset``.
        '''
        self._random_translation = random_translation
        self._random_rotation = random_rotation
        if 'typers' in kwargs:
            typers = kwargs.pop('typers')
            self.examples = mg.ExampleDataset(*typers, **kwargs)
            self.typers = typers
        else:
            self.examples = mg.ExampleDataset(**kwargs)
            self.typers = None
        self.types_files = list(args)
        self.examples.populate(self.types_files)
        # Keep num_labels set here as well as in __setstate__, so freshly
        # constructed and unpickled instances expose the same attributes.
        self.num_labels = self.examples.num_labels()
190191 def __len__ (self ):
191192 return len (self .examples )
192193
    def __getitem__(self, idx):
        '''Return one example as (length, center, coords, atomtypes, radii, labels).

        length is the number of atoms, center is the center of the last
        coordinate set, coords/atomtypes/radii are per-atom tensors and
        labels holds the example's label values.
        '''
        ex = self.examples[idx]
        # Grid center taken from the last coordinate set (the ligand).
        # NOTE(review): captured BEFORE the random transform below, so the
        # augmentation offsets the molecule relative to this center —
        # presumably intentional for translation augmentation; confirm.
        center = torch.tensor(list(ex.coord_sets[-1].center()))
        coordinates = ex.merge_coordinates()
        if self._random_translation > 0 or self._random_rotation:
            # In-place augmentation: rotate about the ligand center and/or
            # translate by up to self._random_translation in each axis.
            mg.Transform(ex.coord_sets[-1].center(), self._random_translation, self._random_rotation).forward(coordinates, coordinates)
        if coordinates.has_vector_types() and coordinates.size() > 0:
            # Vector-typed atoms: one row of type components per atom.
            atomtypes = torch.tensor(coordinates.type_vector.tonumpy(), dtype=torch.long).type('torch.FloatTensor')
        else:
            # Index-typed atoms: one integer type per atom.
            atomtypes = torch.tensor(coordinates.type_index.tonumpy(), dtype=torch.long).type('torch.FloatTensor')
        coords = torch.tensor(coordinates.coords.tonumpy())
        length = len(coords)
        radii = torch.tensor(coordinates.radii.tonumpy())
        labels = torch.tensor(ex.labels)
        return length, center, coords, atomtypes, radii, labels
209+
205210
206211 def __getstate__ (self ):
207- settings = self .examples .settings ()
212+ settings = self .examples .settings ()
208213 keyword_dict = {sett : getattr (settings , sett ) for sett in dir (settings ) if not sett .startswith ('__' )}
209214 if self .typers is not None : ## This will fail if self.typers is not none, need a way to pickle AtomTypers
210215 raise NotImplementedError ('MolDataset does not support pickling when not using the default Gnina atom typers, this uses %s' .format (str (self .typers )))
211216 keyword_dict ['typers' ] = self .typers
217+ keyword_dict ['random_translation' ] = self ._random_translation
218+ keyword_dict ['random_rotation' ] = self ._random_rotation
212219 return keyword_dict , self .types_files
213220
214221 def __setstate__ (self ,state ):
215222 kwargs = state [0 ]
216-
223+ self ._random_translation = kwargs .pop ('random_translation' )
224+ self ._random_rotation = kwargs .pop ('random_rotation' )
217225 if 'typers' in kwargs :
218- typers = kwargs ['typers' ]
219- del kwargs ['typers' ]
226+ typers = kwargs .pop ('typers' )
220227 self .examples = mg .ExampleDataset (* typers , ** kwargs )
221228 self .typers = typers
222229 else :
@@ -225,33 +232,19 @@ def __setstate__(self,state):
225232 self .types_files = list (state [1 ])
226233 self .examples .populate (self .types_files )
227234
235+
228236 self .num_labels = self .examples .num_labels ()
229237
230238 @staticmethod
231239 def collateMolDataset (batch ):
232240 '''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset.
233241 Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch'''
234- lens = []
235- centers = []
236- lcoords = []
237- ltypes = []
238- lradii = []
239- labels = []
240- for center ,coords ,types ,radii ,label in batch :
241- lens .append (coords .shape [0 ])
242- centers .append (center )
243- lcoords .append (coords )
244- ltypes .append (types )
245- lradii .append (radii )
246- labels .append (torch .tensor (label ))
247-
248-
249- lengths = torch .tensor (lens )
250- lcoords = torch .nn .utils .rnn .pad_sequence (lcoords , batch_first = True )
251- ltypes = torch .nn .utils .rnn .pad_sequence (ltypes , batch_first = True )
252- lradii = torch .nn .utils .rnn .pad_sequence (lradii , batch_first = True )
253-
254- centers = torch .stack (centers ,dim = 0 )
255- labels = torch .stack (labels ,dim = 0 )
242+ batch_list = list (zip (* batch ))
243+ lengths = torch .tensor (batch_list [0 ])
244+ centers = torch .stack (batch_list [1 ], dim = 0 )
245+ coords = torch .nn .utils .rnn .pad_sequence (batch_list [2 ], batch_first = True )
246+ types = torch .nn .utils .rnn .pad_sequence (batch_list [3 ], batch_first = True )
247+ radii = torch .nn .utils .rnn .pad_sequence (batch_list [4 ], batch_first = True )
248+ labels = torch .stack (batch_list [5 ], dim = 0 )
256249
257- return lengths , centers , lcoords , ltypes , lradii , labels
250+ return lengths , centers , coords , types , radii , labels
0 commit comments