Skip to content

Commit efc5cb3

Browse files
authored
Merge pull request #121 from gnina/MolDataset
Updated MolDataset
2 parents 5a642b1 + 0b10133 commit efc5cb3

2 files changed

Lines changed: 32 additions & 39 deletions

File tree

python/torch_bindings.py

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ def extra_repr(self):
159159

160160
class MolDataset(torch.utils.data.Dataset):
161161
'''A pytorch mappable dataset for molgrid training files.'''
162-
def __init__(self, *args, **kwargs):
162+
def __init__(self, *args,
163+
random_translation: float=0.0,
164+
random_rotation: bool=False,
165+
**kwargs):
163166
'''Initialize mappable MolGridDataset.
164167
:param input(s): File name(s) of training example files
165168
:param typers: A tuple of AtomTypers to use
@@ -173,9 +176,10 @@ def __init__(self, *args, **kwargs):
173176
:param ligmolcache: precalculated molcache2 file for ligand; if doesn't exist, will look in data_root
174177
'''
175178

179+
self._random_translation, self._random_rotation = random_translation, random_rotation
180+
print(self._random_translation, self._random_rotation)
176181
if 'typers' in kwargs:
177-
typers = kwargs['typers']
178-
del kwargs['typers']
182+
typers = kwargs.pop('typers')
179183
self.examples = mg.ExampleDataset(*typers,**kwargs)
180184
self.typers = typers
181185
else:
@@ -184,39 +188,42 @@ def __init__(self, *args, **kwargs):
184188
self.types_files = list(args)
185189
self.examples.populate(self.types_files)
186190

187-
self.num_labels = self.examples.num_labels()
188-
189-
190191
def __len__(self):
191192
return len(self.examples)
192193

193194
def __getitem__(self, idx):
194195
ex = self.examples[idx]
195-
center = torch.tensor([i for i in ex.coord_sets[-1].center()])
196+
center = torch.tensor(list(ex.coord_sets[-1].center()))
196197
coordinates = ex.merge_coordinates()
198+
if self._random_translation > 0 or self._random_rotation:
199+
mg.Transform(ex.coord_sets[-1].center(), self._random_translation, self._random_rotation).forward(coordinates, coordinates)
197200
if coordinates.has_vector_types() and coordinates.size() > 0:
198201
atomtypes = torch.tensor(coordinates.type_vector.tonumpy(),dtype=torch.long).type('torch.FloatTensor')
199202
else:
200203
atomtypes = torch.tensor(coordinates.type_index.tonumpy(),dtype=torch.long).type('torch.FloatTensor')
201204
coords = torch.tensor(coordinates.coords.tonumpy())
205+
length = len(coords)
202206
radii = torch.tensor(coordinates.radii.tonumpy())
203-
labels = [ex.labels[lab] for lab in range(self.num_labels)]
204-
return center, coords, atomtypes, radii, labels
207+
labels = torch.tensor(ex.labels)
208+
return length, center, coords, atomtypes, radii, labels
209+
205210

206211
def __getstate__(self):
207-
settings = self.examples.settings()
212+
settings = self.examples.settings()
208213
keyword_dict = {sett: getattr(settings, sett) for sett in dir(settings) if not sett.startswith('__')}
209214
if self.typers is not None: ## This will fail if self.typers is not none, need a way to pickle AtomTypers
210215
raise NotImplementedError('MolDataset does not support pickling when not using the default Gnina atom typers, this uses %s'.format(str(self.typers)))
211216
keyword_dict['typers'] = self.typers
217+
keyword_dict['random_translation'] = self._random_translation
218+
keyword_dict['random_rotation'] = self._random_rotation
212219
return keyword_dict, self.types_files
213220

214221
def __setstate__(self,state):
215222
kwargs=state[0]
216-
223+
self._random_translation = kwargs.pop('random_translation')
224+
self._random_rotation = kwargs.pop('random_rotation')
217225
if 'typers' in kwargs:
218-
typers = kwargs['typers']
219-
del kwargs['typers']
226+
typers = kwargs.pop('typers')
220227
self.examples = mg.ExampleDataset(*typers, **kwargs)
221228
self.typers = typers
222229
else:
@@ -225,33 +232,19 @@ def __setstate__(self,state):
225232
self.types_files = list(state[1])
226233
self.examples.populate(self.types_files)
227234

235+
228236
self.num_labels = self.examples.num_labels()
229237

230238
@staticmethod
231239
def collateMolDataset(batch):
232240
'''collate_fn for use in torch.utils.data.Dataloader when using the MolDataset.
233241
Returns lengths, centers, coords, types, radii, labels all padded to fit maximum size of batch'''
234-
lens = []
235-
centers = []
236-
lcoords = []
237-
ltypes = []
238-
lradii = []
239-
labels = []
240-
for center,coords,types,radii,label in batch:
241-
lens.append(coords.shape[0])
242-
centers.append(center)
243-
lcoords.append(coords)
244-
ltypes.append(types)
245-
lradii.append(radii)
246-
labels.append(torch.tensor(label))
247-
248-
249-
lengths = torch.tensor(lens)
250-
lcoords = torch.nn.utils.rnn.pad_sequence(lcoords, batch_first=True)
251-
ltypes = torch.nn.utils.rnn.pad_sequence(ltypes, batch_first=True)
252-
lradii = torch.nn.utils.rnn.pad_sequence(lradii, batch_first=True)
253-
254-
centers = torch.stack(centers,dim=0)
255-
labels = torch.stack(labels,dim=0)
242+
batch_list = list(zip(*batch))
243+
lengths = torch.tensor(batch_list[0])
244+
centers = torch.stack(batch_list[1], dim=0)
245+
coords = torch.nn.utils.rnn.pad_sequence(batch_list[2], batch_first=True)
246+
types = torch.nn.utils.rnn.pad_sequence(batch_list[3], batch_first=True)
247+
radii = torch.nn.utils.rnn.pad_sequence(batch_list[4], batch_first=True)
248+
labels = torch.stack(batch_list[5], dim=0)
256249

257-
return lengths, centers, lcoords, ltypes, lradii, labels
250+
return lengths, centers, coords, types, radii, labels

test/test_example_provider.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def test_pytorch_dataset():
366366
ex = e.next()
367367
coordinates = ex.merge_coordinates()
368368

369-
center, coords, types, radii, labels = m[0]
369+
lengths, center, coords, types, radii, labels = m[0]
370370

371371
assert list(center.shape) == [3]
372372
np.testing.assert_allclose(coords, coordinates.coords.tonumpy())
@@ -378,7 +378,7 @@ def test_pytorch_dataset():
378378
np.testing.assert_allclose(labels[1], 6.05)
379379
np.testing.assert_allclose(labels[-1], 0.162643)
380380

381-
center, coords, types, radii, labels = m[-1]
381+
lengths, center, coords, types, radii, labels = m[-1]
382382
assert labels[0] == 0
383383
np.testing.assert_allclose(labels[1], -10.3)
384384

@@ -396,7 +396,7 @@ def test_pytorch_dataset():
396396
assert radii.shape[0] == 8
397397
assert labels.shape[0] == 8
398398

399-
mcenter, mcoords, mtypes, mradii, mlabels = m[10]
399+
mlengths, mcenter, mcoords, mtypes, mradii, mlabels = m[10]
400400
np.testing.assert_allclose(center[2], mcenter)
401401
np.testing.assert_allclose(coords[2][:lengths[2]], mcoords)
402402
np.testing.assert_allclose(types[2][:lengths[2]], mtypes)

0 commit comments

Comments
 (0)