Skip to content

Commit 0aa3fa3

Browse files
committed
bugfix: processor device, in-place ops and more
1 parent 0467914 commit 0aa3fa3

4 files changed

Lines changed: 121 additions & 113 deletions

File tree

matdeeplearn/common/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def get_dataloader(
103103
batch_size: int,
104104
num_workers: int = 0,
105105
sampler=None,
106+
shuffle=True
106107
):
107108
"""
108109
Returns a single dataloader for a given dataset
@@ -124,7 +125,7 @@ def get_dataloader(
124125
loader = DataLoader(
125126
dataset,
126127
batch_size=batch_size,
127-
shuffle=(sampler is None),
128+
shuffle=shuffle,
128129
num_workers=num_workers,
129130
sampler=sampler,
130131
)
Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,44 @@
1-
import os
1+
import torch, os
22

3-
import torch
43
from torch_geometric.data import InMemoryDataset
54

6-
75
class StructureDataset(InMemoryDataset):
86
def __init__(
97
self,
10-
root,
11-
processed_data_path,
12-
transform=None,
13-
pre_transform=None,
8+
root,
9+
processed_data_path,
10+
transform=None,
11+
pre_transform=None,
1412
pre_filter=None,
13+
device=None
1514
):
1615
self.root = root
1716
self.processed_data_path = processed_data_path
18-
super(StructureDataset, self).__init__(
19-
root, transform, pre_transform, pre_filter
20-
)
21-
self.data, self.slices = torch.load(self.processed_paths[0])
22-
17+
super(StructureDataset, self).__init__(root, transform, pre_transform, pre_filter)
18+
19+
if device is None:
20+
try:
21+
self.data, self.slices = torch.load(self.processed_paths[0])
22+
except:
23+
self.data, self.slices = torch.load(self.processed_paths[0], map_location=torch.device('cpu'))
24+
else:
25+
if device == 'cpu':
26+
self.data, self.slices = torch.load(self.processed_paths[0], map_location=torch.device(device))
27+
else:
28+
self.data, self.slices = torch.load(self.processed_paths[0])
29+
2330
@property
2431
def raw_file_names(self):
25-
"""
26-
The name of the files in the self.raw_dir folder
32+
'''
33+
The name of the files in the self.raw_dir folder
2734
that must be present in order to skip downloading.
28-
"""
35+
'''
2936
return []
3037

3138
def download(self):
32-
"""
39+
'''
3340
Download required data files; to be implemented
34-
"""
41+
'''
3542
pass
3643

3744
@property
@@ -40,12 +47,11 @@ def processed_dir(self):
4047

4148
@property
4249
def processed_file_names(self):
43-
"""
44-
The name of the files in the self.processed_dir
50+
'''
51+
The name of the files in the self.processed_dir
4552
folder that must be present in order to skip processing.
46-
"""
53+
'''
4754
return ["data.pt"]
4855

49-
5056
class LargeStructureDataset(InMemoryDataset):
51-
pass
57+
pass

0 commit comments

Comments
 (0)