Skip to content

Commit 921de63

Browse files
authored
Merge pull request #122 from gnina/MolDataset
MolDataset -> MolMapDataset and MolIterDataset
2 parents efc5cb3 + c355145 commit 921de63

2 files changed

Lines changed: 182 additions & 9 deletions

File tree

python/torch_bindings.py

Lines changed: 112 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
import molgrid as mg
33
import types
4+
from itertools import islice
5+
46
def tensor_as_grid(t):
57
'''Return a Grid view of tensor t'''
68
gname = 'Grid'
@@ -157,7 +159,7 @@ def extra_repr(self):
157159
self.gmaker.get_resolution(), self.gmaker.get_dimension(), self.center[0], self.center[1], self.center[2])
158160

159161

160-
class MolDataset(torch.utils.data.Dataset):
162+
class MolMapDataset(torch.utils.data.Dataset):
161163
'''A pytorch mappable dataset for molgrid training files.'''
162164
def __init__(self, *args,
163165
random_translation: float=0.0,
@@ -177,7 +179,6 @@ def __init__(self, *args,
177179
'''
178180

179181
self._random_translation, self._random_rotation = random_translation, random_rotation
180-
print(self._random_translation, self._random_rotation)
181182
if 'typers' in kwargs:
182183
typers = kwargs.pop('typers')
183184
self.examples = mg.ExampleDataset(*typers,**kwargs)
@@ -212,7 +213,7 @@ def __getstate__(self):
212213
settings = self.examples.settings()
213214
keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')}
214215
if self.typers is not None: ## This will fail if self.typers is not none, need a way to pickle AtomTypers
215-
raise NotImplementedError('MolDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers)))
216+
raise NotImplementedError('MolMapDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers)))
216217
keyword_dict['typers'] = self.typers
217218
keyword_dict['random_translation'] = self._random_translation
218219
keyword_dict['random_rotation'] = self._random_rotation
@@ -233,11 +234,9 @@ def __setstate__(self,state):
233234
self.examples.populate(self.types_files)
234235

235236

236-
self.num_labels = self.examples.num_labels()
237-
238237
@staticmethod
239238
def collateMolDataset(batch):
240-
'''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset.
239+
'''collate_fn for use in torch.utils.data.Dataloader when using the MolMapDataset.
241240
Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch'''
242241
batch_list = list(zip(*batch))
243242
lengths = torch.tensor(batch_list[0])
@@ -248,3 +247,110 @@ def collateMolDataset(batch):
248247
labels = torch.stack(batch_list[5], dim=0)
249248

250249
return lengths, centers, coords, types, radii, labels
250+
251+
class MolIterDataset(torch.utils.data.IterableDataset):
    '''A pytorch iterable dataset for molgrid training files. Use with a DataLoader(batch_size=None) for best results.'''

    def __init__(self, *args,
            random_translation: float = 0.0,
            random_rotation: bool = False,
            **kwargs):
        '''Initialize iterable MolGridDataset.

        :param input(s): File name(s) of training example files
        :param typers: A tuple of AtomTypers to use
        :type typers: tuple
        :param random_translation: maximum random translation (in distance units) applied per example
        :param random_rotation: apply a uniform random rotation per example
        :param cache_structs: retain coordinates in memory for faster training
        :param add_hydrogens: protonate molecules read using openbabel
        :param duplicate_first: clone the first coordinate set to be paired with each of the remaining (receptor-ligand pairs)
        :param make_vector_types: convert index types into one-hot encoded vector types
        :param data_root: prefix for data files
        :param recmolcache: precalculated molcache2 file for receptor (first molecule); if doesn't exist, will look in data_root
        :param ligmolcache: precalculated molcache2 file for ligand; if doesn't exist, will look in data_root
        '''
        self._random_translation, self._random_rotation = random_translation, random_rotation
        if 'typers' in kwargs:
            typers = kwargs.pop('typers')
            self.examples = mg.ExampleProvider(*typers, **kwargs)
            self.typers = typers
        else:
            self.examples = mg.ExampleProvider(**kwargs)
            self.typers = None
        self.types_files = list(args)
        self.examples.populate(self.types_files)

        # cached so batch_to_tensors can size the label tensor without re-querying
        self._num_labels = self.examples.num_labels()

    def generate(self):
        '''Yield one tuple of batch tensors per provider batch until the provider is exhausted.'''
        for batch in self.examples:
            yield self.batch_to_tensors(batch)

    def batch_to_tensors(self, batch):
        '''Convert a batch of examples to (lengths, centers, coords, types, radii, labels)
        tensors, with the per-atom tensors zero-padded to the largest molecule in the batch.'''
        batch_lengths = torch.zeros(len(batch), dtype=torch.int64)
        batch_centers = torch.zeros((len(batch), 3), dtype=torch.float32)
        batch_coords = []
        batch_atomtypes = []
        batch_radii = []
        batch_labels = torch.zeros((len(batch), self._num_labels), dtype=torch.float32)
        for idx, ex in enumerate(batch):
            length, center, coords, atomtypes, radii, labels = self.example_to_tensor(ex)
            batch_lengths[idx] = length
            batch_centers[idx, :] = center
            batch_coords.append(coords)
            batch_atomtypes.append(atomtypes)
            batch_radii.append(radii)
            batch_labels[idx, :] = labels
        # molecules are ragged along the atom dimension; pad so they stack into one tensor
        pad_coords = torch.nn.utils.rnn.pad_sequence(batch_coords, batch_first=True)
        pad_atomtypes = torch.nn.utils.rnn.pad_sequence(batch_atomtypes, batch_first=True)
        pad_radii = torch.nn.utils.rnn.pad_sequence(batch_radii, batch_first=True)
        return batch_lengths, batch_centers, pad_coords, pad_atomtypes, pad_radii, batch_labels

    def example_to_tensor(self, ex):
        '''Convert a single example to (length, center, coords, atomtypes, radii, labels)
        tensors, applying the configured random translation/rotation to the coordinates.'''
        center = torch.tensor(list(ex.coord_sets[-1].center()))
        coordinates = ex.merge_coordinates()
        if self._random_translation > 0 or self._random_rotation:
            # transform in place, centered on the last coordinate set (the ligand)
            mg.Transform(ex.coord_sets[-1].center(), self._random_translation, self._random_rotation).forward(coordinates, coordinates)
        if coordinates.has_vector_types() and coordinates.size() > 0:
            # NOTE(review): routing vector types through torch.long truncates any fractional
            # one-hot values before the float cast -- confirm this is intended
            atomtypes = torch.tensor(coordinates.type_vector.tonumpy(), dtype=torch.long).float()
        else:
            atomtypes = torch.tensor(coordinates.type_index.tonumpy(), dtype=torch.long).float()
        coords = torch.tensor(coordinates.coords.tonumpy())
        length = len(coords)
        radii = torch.tensor(coordinates.radii.tonumpy())
        labels = torch.tensor(ex.labels)
        return length, center, coords, atomtypes, radii, labels

    def __iter__(self):
        '''Iterate over batches. Under a multi-worker DataLoader each worker takes every
        num_workers-th batch (strided split) so no batch is produced twice.'''
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # single-process loading: yield every batch
            return self.generate()
        return islice(self.generate(), worker_info.id, None, worker_info.num_workers)

    def __getstate__(self):
        '''Pickle support: capture provider settings, typers, and augmentation flags.
        Raises NotImplementedError for non-default atom typers, which cannot be pickled yet.'''
        settings = self.examples.settings()
        keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')}
        if self.typers is not None:  ## need a way to pickle AtomTypers
            # bug fix: was '... %s'.format(...), which never substituted the typers into the message
            raise NotImplementedError('MolIterDataset does not support pickling when not using the default Gnina atom typers, this uses {}'.format(self.typers))
        keyword_dict['typers'] = self.typers
        keyword_dict['random_translation'] = self._random_translation
        keyword_dict['random_rotation'] = self._random_rotation
        return keyword_dict, self.types_files

    def __setstate__(self, state):
        '''Rebuild the ExampleProvider from pickled settings and repopulate the types files.'''
        kwargs = state[0]
        self._random_translation = kwargs.pop('random_translation')
        self._random_rotation = kwargs.pop('random_rotation')
        if 'typers' in kwargs:
            typers = kwargs.pop('typers')
            self.examples = mg.ExampleProvider(*typers, **kwargs)
            self.typers = typers
        else:
            self.examples = mg.ExampleProvider(**kwargs)
            self.typers = None
        self.types_files = list(state[1])
        self.examples.populate(self.types_files)

test/test_example_provider.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,12 +354,12 @@ def test_example_provider_iterator_interface():
354354
break
355355

356356

357-
def test_pytorch_dataset():
357+
def test_pytorch_mapdataset():
358358
fname = datadir + "/small.types"
359359

360360
e = molgrid.ExampleProvider(data_root=datadir + "/structs")
361361
e.populate(fname)
362-
m = molgrid.MolDataset(fname, data_root=datadir + "/structs")
362+
m = molgrid.MolMapDataset(fname, data_root=datadir + "/structs")
363363

364364
assert len(m) == 1000
365365

@@ -384,7 +384,7 @@ def test_pytorch_dataset():
384384

385385
'''Testing out the collate_fn when used with torch.utils.data.DataLoader'''
386386
torch_loader = torch.utils.data.DataLoader(
387-
m, batch_size=8, collate_fn=molgrid.MolDataset.collateMolDataset)
387+
m, batch_size=8, collate_fn=molgrid.MolMapDataset.collateMolDataset)
388388
iterator = iter(torch_loader)
389389
next(iterator)
390390
lengths, center, coords, types, radii, labels = next(iterator)
@@ -422,6 +422,73 @@ def test_pytorch_dataset():
422422
singlegrid = molgrid.MGrid4f(*shape)
423423
gmaker.forward(ex, singlegrid.cpu())
424424
np.testing.assert_allclose(mgrid[2].tonumpy(),singlegrid.tonumpy(),atol=1e-5)
425+
426+
def test_pytorch_iterdataset():
    '''Check MolIterDataset against a plain ExampleProvider, single- and multi-worker.'''
    fname = datadir + "/small.types"

    BSIZE = 25
    reference = molgrid.ExampleProvider(data_root=datadir + "/structs", default_batch_size=BSIZE)
    reference.populate(fname)
    dataset = molgrid.MolIterDataset(fname, data_root=datadir + "/structs", default_batch_size=BSIZE)
    batches = iter(dataset)

    first_ex = reference.next()
    ref_coords = first_ex.merge_coordinates()

    lengths, centers, coords, types, radii, labels = next(batches)

    assert list(centers.shape) == [BSIZE, 3]
    n0 = lengths[0]
    np.testing.assert_allclose(coords[0, :n0, :], ref_coords.coords.tonumpy())
    np.testing.assert_allclose(types[0, :n0], ref_coords.type_index.tonumpy())
    np.testing.assert_allclose(radii[0, :n0], ref_coords.radii.tonumpy())

    assert len(labels) == BSIZE
    assert len(labels[0]) == 3
    assert labels[0, 0] == 1
    np.testing.assert_allclose(labels[0, 1], 6.05)
    np.testing.assert_allclose(labels[0, -1], 0.162643)

    # ensure it works with more than 1 worker
    dataset.examples.reset()
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
    loader_iter = iter(loader)
    next(loader_iter)
    lengths, center, coords, types, radii, labels = next(loader_iter)
    assert len(lengths) == BSIZE
    for tensor in (center, coords, types, radii, labels):
        assert tensor.shape[0] == BSIZE

    # third example of the second batch should match the reference provider
    reference.reset()
    reference.next_batch()
    second_batch = reference.next_batch()
    ref_coords = second_batch[2].merge_coordinates()
    np.testing.assert_allclose(center[2], np.array(list(second_batch[2].coord_sets[-1].center())))
    np.testing.assert_allclose(coords[2, :lengths[2]], ref_coords.coords.tonumpy())
    np.testing.assert_allclose(types[2, :lengths[2]], ref_coords.type_index.tonumpy())
    np.testing.assert_allclose(radii[2, :lengths[2]], ref_coords.radii.tonumpy())
    assert len(labels[2]) == reference.num_labels()
    assert labels[2, 0] == second_batch[2].labels[0]
    assert labels[2, 1] == second_batch[2].labels[1]

    # gridding the padded tensors must agree between CPU and GPU...
    gmaker = molgrid.GridMaker()
    dims = gmaker.grid_dimensions(reference.num_types())
    cpu_grid = molgrid.MGrid5f(BSIZE, *dims)
    gmaker.forward(center, coords, types, radii, cpu_grid.cpu())

    gpu_grid = molgrid.MGrid5f(BSIZE, *dims)
    gmaker.forward(center.cuda(), coords.cuda(), types.cuda(), radii.cuda(), gpu_grid.gpu())

    np.testing.assert_allclose(cpu_grid.tonumpy(), gpu_grid.tonumpy(), atol=1e-5)

    # ...and with the grid produced directly from the standard provider
    provider_grid = molgrid.MGrid5f(BSIZE, *dims)
    gmaker.forward(second_batch, provider_grid.cpu())
    np.testing.assert_allclose(gpu_grid.tonumpy(), provider_grid.tonumpy(), atol=1e-5)
425492

426493

427494
def test_duplicated_examples():

0 commit comments

Comments
 (0)