1- import os
1+ import torch , os
22
3- import torch
43from torch_geometric .data import InMemoryDataset
54
6-
75class StructureDataset (InMemoryDataset ):
86 def __init__ (
97 self ,
10- root ,
11- processed_data_path ,
12- transform = None ,
13- pre_transform = None ,
8+ root ,
9+ processed_data_path ,
10+ transform = None ,
11+ pre_transform = None ,
1412 pre_filter = None ,
13+ device = None
1514 ):
1615 self .root = root
1716 self .processed_data_path = processed_data_path
18- super (StructureDataset , self ).__init__ (
19- root , transform , pre_transform , pre_filter
20- )
21- self .data , self .slices = torch .load (self .processed_paths [0 ])
22-
17+ super (StructureDataset , self ).__init__ (root , transform , pre_transform , pre_filter )
18+
19+ if device is None :
20+ try :
21+ self .data , self .slices = torch .load (self .processed_paths [0 ])
22+ except :
23+ self .data , self .slices = torch .load (self .processed_paths [0 ], map_location = torch .device ('cpu' ))
24+ else :
25+ if device == 'cpu' :
26+ self .data , self .slices = torch .load (self .processed_paths [0 ], map_location = torch .device (device ))
27+ else :
28+ self .data , self .slices = torch .load (self .processed_paths [0 ])
29+
2330 @property
2431 def raw_file_names (self ):
25- """
26- The name of the files in the self.raw_dir folder
32+ '''
33+ The name of the files in the self.raw_dir folder
2734 that must be present in order to skip downloading.
28- """
35+ '''
2936 return []
3037
3138 def download (self ):
32- """
39+ '''
3340 Download required data files; to be implemented
34- """
41+ '''
3542 pass
3643
3744 @property
@@ -40,12 +47,11 @@ def processed_dir(self):
4047
4148 @property
4249 def processed_file_names (self ):
43- """
44- The name of the files in the self.processed_dir
50+ '''
51+ The name of the files in the self.processed_dir
4552 folder that must be present in order to skip processing.
46- """
53+ '''
4754 return ["data.pt" ]
4855
49-
5056class LargeStructureDataset (InMemoryDataset ):
51- pass
57+ pass
0 commit comments