import math
import warnings

import torch
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from matdeeplearn.preprocessor.datasets import LargeStructureDataset, StructureDataset
from matdeeplearn.preprocessor.transforms import GetY

# train / validation / test split
def dataset_split(
    dataset,
    train_size: float = 0.8,
    valid_size: float = 0.05,
    test_size: float = 0.15,
    seed: int = 1234,
):
    """
    Splits an input dataset into 3 subsets: train, validation, test.
    Requires train_size + valid_size + test_size = 1 (up to float tolerance);
    otherwise a warning is issued and the default 80/5/15 split is used.

    Parameters
    ----------
    dataset: matdeeplearn.preprocessor.datasets.StructureDataset
        a dataset object that contains the target data

    train_size: float
        a float between 0.0 and 1.0 that represents the proportion
        of the dataset to use as the training set

    valid_size: float
        a float between 0.0 and 1.0 that represents the proportion
        of the dataset to use as the validation set

    test_size: float
        a float between 0.0 and 1.0 that represents the proportion
        of the dataset to use as the test set

    seed: int
        seed for the split's random generator, for reproducibility
    """
    # Exact float comparison (!=) spuriously rejects valid splits such as
    # 0.7/0.2/0.1, whose binary-float sum is 0.9999999999999999, silently
    # replacing the caller's split with the default. Compare with tolerance.
    if not math.isclose(train_size + valid_size + test_size, 1.0, abs_tol=1e-8):
        warnings.warn("Invalid sizes detected. Using default split of 80/5/15.")
        train_size, valid_size, test_size = 0.8, 0.05, 0.15

    dataset_size = len(dataset)

    # int() truncation may leave a remainder; collect it in an unused split
    # so the lengths passed to random_split sum exactly to dataset_size.
    train_len = int(train_size * dataset_size)
    valid_len = int(valid_size * dataset_size)
    test_len = int(test_size * dataset_size)
    unused_len = dataset_size - train_len - valid_len - test_len

    train_dataset, val_dataset, test_dataset, _unused = random_split(
        dataset,
        [train_len, valid_len, test_len, unused_len],
        generator=torch.Generator().manual_seed(seed),
    )

    return train_dataset, val_dataset, test_dataset
6258
59+
def get_dataset(
    data_path, target_index: int = 0, transform_type="GetY", large_dataset=False
):
    """
    get dataset according to data_path
    this assumes that the data has already been processed and
    data.pt file exists in data_path/processed/ folder

    Parameters
    ----------

    data_path: str
        path to the folder containing data.pt file

    target_index: int
        passed to the transform as ``index`` — presumably selects which
        target column of the data to use for the current run/experiment

    transform_type: str
        name of the transformation function/class to be applied;
        only "GetY" is currently supported

    large_dataset: bool
        if True, use the LargeStructureDataset backend instead of
        StructureDataset

    Raises
    ------
    ValueError
        if transform_type does not name a known transform
    """

    # set transform method
    if transform_type == "GetY":
        T = GetY
    else:
        # Bug fix: the message was a plain string referencing a nonexistent
        # name, so it printed the literal text "{transform}". Use an f-string
        # with the actual parameter so the offending value is reported.
        raise ValueError(f"No such transform found for {transform_type}")

    # check if large dataset is needed
    if large_dataset:
        Dataset = LargeStructureDataset
    else:
        Dataset = StructureDataset

    transform = T(index=target_index)

    return Dataset(data_path, processed_data_path="", transform=transform)
99+
105100
def get_dataloader(
    dataset,
    batch_size: int,
    num_workers: int = 0,
    sampler=None,
):
    """
    Returns a single dataloader for a given dataset

    Parameters
    ----------
    dataset: matdeeplearn.preprocessor.datasets.StructureDataset
        a dataset object that contains the target data

    batch_size: int
        size of each batch

    num_workers: int
        how many subprocesses to use for data loading. 0 means that
        the data will be loaded in the main process.

    sampler: optional sampler defining how examples are drawn; when one
        is supplied, shuffling must be disabled (the two are mutually
        exclusive in DataLoader)
    """
    # Shuffle only when the caller has not provided a custom sampler.
    shuffle = sampler is None

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        sampler=sampler,
    )
0 commit comments