11import torch
22import molgrid as mg
33import types
4+ from itertools import islice
5+
46def tensor_as_grid (t ):
57 '''Return a Grid view of tensor t'''
68 gname = 'Grid'
@@ -157,7 +159,7 @@ def extra_repr(self):
157159 self .gmaker .get_resolution (), self .gmaker .get_dimension (), self .center [0 ], self .center [1 ], self .center [2 ])
158160
159161
160- class MolDataset (torch .utils .data .Dataset ):
162+ class MolMapDataset (torch .utils .data .Dataset ):
161163 '''A pytorch mappable dataset for molgrid training files.'''
162164 def __init__ (self , * args ,
163165 random_translation : float = 0.0 ,
@@ -177,7 +179,6 @@ def __init__(self, *args,
177179 '''
178180
179181 self ._random_translation , self ._random_rotation = random_translation , random_rotation
180- print (self ._random_translation , self ._random_rotation )
181182 if 'typers' in kwargs :
182183 typers = kwargs .pop ('typers' )
183184 self .examples = mg .ExampleDataset (* typers ,** kwargs )
@@ -212,7 +213,7 @@ def __getstate__(self):
212213 settings = self .examples .settings ()
213214 keyword_dict = {sett : getattr (settings , sett ) for sett in dir (settings ) if not sett .startswith ('__' )}
214215 if self .typers is not None : ## This will fail if self.typers is not none, need a way to pickle AtomTypers
215- raise NotImplementedError ('MolDataset does not support pickling when not using the default Gnina atom typers, this uses %s' .format (str (self .typers )))
216+ raise NotImplementedError ('MolMapDataset does not support pickling when not using the default Gnina atom typers, this uses %s' .format (str (self .typers )))
216217 keyword_dict ['typers' ] = self .typers
217218 keyword_dict ['random_translation' ] = self ._random_translation
218219 keyword_dict ['random_rotation' ] = self ._random_rotation
@@ -233,11 +234,9 @@ def __setstate__(self,state):
233234 self .examples .populate (self .types_files )
234235
235236
236- self .num_labels = self .examples .num_labels ()
237-
238237 @staticmethod
239238 def collateMolDataset (batch ):
240- '''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset .
239+ '''collate_fn for use in torch.utils.data.Dataloader when using the MolMapDataset .
241240 Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch'''
242241 batch_list = list (zip (* batch ))
243242 lengths = torch .tensor (batch_list [0 ])
@@ -248,3 +247,110 @@ def collateMolDataset(batch):
248247 labels = torch .stack (batch_list [5 ], dim = 0 )
249248
250249 return lengths , centers , coords , types , radii , labels
250+
class MolIterDataset(torch.utils.data.IterableDataset):
    '''A pytorch iterable dataset for molgrid training files.

    Streams batches straight from a molgrid ExampleProvider, so batching is
    done by the provider, not the DataLoader. Use with a
    DataLoader(batch_size=None) for best results.
    '''
    def __init__(self, *args,
                 random_translation: float = 0.0,
                 random_rotation: bool = False,
                 **kwargs):
        '''Initialize iterable MolGridDataset.

        :param input(s): File name(s) of training example files
        :param typers: A tuple of AtomTypers to use
        :type typers: tuple
        :param cache_structs: retain coordinates in memory for faster training
        :param add_hydrogens: protonate molecules read using openbabel
        :param duplicate_first: clone the first coordinate set to be paired with each of the remaining (receptor-ligand pairs)
        :param make_vector_types: convert index types into one-hot encoded vector types
        :param data_root: prefix for data files
        :param recmolcache: precalculated molcache2 file for receptor (first molecule); if doesn't exist, will look in data_root
        :param ligmolcache: precalculated molcache2 file for ligand; if doesn't exist, will look in data_root
        '''
        self._random_translation, self._random_rotation = random_translation, random_rotation
        if 'typers' in kwargs:
            typers = kwargs.pop('typers')
            self.examples = mg.ExampleProvider(*typers, **kwargs)
            self.typers = typers
        else:
            self.examples = mg.ExampleProvider(**kwargs)
            self.typers = None
        self.types_files = list(args)
        self.examples.populate(self.types_files)

        # cached so batch_to_tensors can pre-size the label tensor
        self._num_labels = self.examples.num_labels()

    def generate(self):
        '''Yield each provider batch converted to padded torch tensors.'''
        for batch in self.examples:
            yield self.batch_to_tensors(batch)

    def batch_to_tensors(self, batch):
        '''Convert a molgrid batch into torch tensors.

        Returns (lengths, centers, coords, types, radii, labels); the
        per-example coords/types/radii are zero-padded to the largest
        molecule in the batch.
        '''
        batch_lengths = torch.zeros(len(batch), dtype=torch.int64)
        batch_centers = torch.zeros((len(batch), 3), dtype=torch.float32)
        batch_coords = []
        batch_atomtypes = []
        batch_radii = []
        batch_labels = torch.zeros((len(batch), self._num_labels), dtype=torch.float32)
        for idx, ex in enumerate(batch):
            length, center, coords, atomtypes, radii, labels = self.example_to_tensor(ex)
            batch_lengths[idx] = length
            batch_centers[idx, :] = center
            batch_coords.append(coords)
            batch_atomtypes.append(atomtypes)
            batch_radii.append(radii)
            batch_labels[idx, :] = labels
        pad_coords = torch.nn.utils.rnn.pad_sequence(batch_coords, batch_first=True)
        pad_atomtypes = torch.nn.utils.rnn.pad_sequence(batch_atomtypes, batch_first=True)
        pad_radii = torch.nn.utils.rnn.pad_sequence(batch_radii, batch_first=True)
        return batch_lengths, batch_centers, pad_coords, pad_atomtypes, pad_radii, batch_labels

    def example_to_tensor(self, ex):
        '''Convert a single molgrid Example into
        (length, center, coords, atomtypes, radii, labels) tensors,
        applying the configured random translation/rotation if any.'''
        center = torch.tensor(list(ex.coord_sets[-1].center()))
        coordinates = ex.merge_coordinates()
        if self._random_translation > 0 or self._random_rotation:
            # transform in place about the last coordinate set's center
            mg.Transform(ex.coord_sets[-1].center(), self._random_translation,
                         self._random_rotation).forward(coordinates, coordinates)
        # vector types take precedence when present and non-empty
        if coordinates.has_vector_types() and coordinates.size() > 0:
            atomtypes = torch.tensor(coordinates.type_vector.tonumpy(), dtype=torch.long).float()
        else:
            atomtypes = torch.tensor(coordinates.type_index.tonumpy(), dtype=torch.long).float()
        coords = torch.tensor(coordinates.coords.tonumpy())
        length = len(coords)
        radii = torch.tensor(coordinates.radii.tonumpy())
        labels = torch.tensor(ex.labels)
        return length, center, coords, atomtypes, radii, labels

    def __iter__(self):
        '''Iterate over batches, sharding round-robin across DataLoader workers
        so no batch is yielded by more than one worker.'''
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # single-process loading: yield every batch
            return self.generate()
        # worker i yields batches i, i+n, i+2n, ... for n workers
        return islice(self.generate(), worker_info.id, None, worker_info.num_workers)

    def __getstate__(self):
        settings = self.examples.settings()
        keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')}
        if self.typers is not None:
            # custom AtomTypers are not picklable; only the default Gnina typers work
            raise NotImplementedError(
                'MolIterDataset does not support pickling when not using the default Gnina atom typers, this uses {}'.format(self.typers))
        keyword_dict['typers'] = self.typers
        keyword_dict['random_translation'] = self._random_translation
        keyword_dict['random_rotation'] = self._random_rotation
        return keyword_dict, self.types_files

    def __setstate__(self, state):
        kwargs = state[0]
        self._random_translation = kwargs.pop('random_translation')
        self._random_rotation = kwargs.pop('random_rotation')
        # __getstate__ stores typers even when None; only unpack a real tuple,
        # otherwise mg.ExampleProvider(*None, ...) would raise TypeError
        typers = kwargs.pop('typers', None)
        if typers:
            self.examples = mg.ExampleProvider(*typers, **kwargs)
            self.typers = typers
        else:
            self.examples = mg.ExampleProvider(**kwargs)
            self.typers = None
        self.types_files = list(state[1])
        self.examples.populate(self.types_files)
0 commit comments