Skip to content

Commit 35ea15f

Browse files
committed
updated processing with cleaner fixes
1 parent 0aa3fa3 commit 35ea15f

4 files changed

Lines changed: 59 additions & 60 deletions

File tree

data/test_data/processed/data.pt

2.1 MB
Binary file not shown.

matdeeplearn/preprocessor/datasets.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,27 @@ def __init__(
1616
self.processed_data_path = processed_data_path
1717
super(StructureDataset, self).__init__(root, transform, pre_transform, pre_filter)
1818

19-
if device is None:
20-
try:
21-
self.data, self.slices = torch.load(self.processed_paths[0])
22-
except:
23-
self.data, self.slices = torch.load(self.processed_paths[0], map_location=torch.device('cpu'))
19+
if not torch.cuda.is_available() or device == "cpu":
20+
self.data, self.slices = torch.load(
21+
self.processed_paths[0],
22+
map_location=torch.device('cpu')
23+
)
2424
else:
25-
if device == 'cpu':
26-
self.data, self.slices = torch.load(self.processed_paths[0], map_location=torch.device(device))
27-
else:
28-
self.data, self.slices = torch.load(self.processed_paths[0])
25+
self.data, self.slices = torch.load(self.processed_paths[0])
26+
2927

3028
@property
3129
def raw_file_names(self):
32-
'''
30+
"""
3331
The name of the files in the self.raw_dir folder
3432
that must be present in order to skip downloading.
35-
'''
33+
"""
3634
return []
3735

3836
def download(self):
39-
'''
37+
"""
4038
Download required data files; to be implemented
41-
'''
39+
"""
4240
pass
4341

4442
@property
@@ -47,11 +45,11 @@ def processed_dir(self):
4745

4846
@property
4947
def processed_file_names(self):
50-
'''
48+
"""
5149
The name of the files in the self.processed_dir
5250
folder that must be present in order to skip processing.
53-
'''
51+
"""
5452
return ["data.pt"]
5553

5654
class LargeStructureDataset(InMemoryDataset):
57-
pass
55+
pass

matdeeplearn/preprocessor/helpers.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def one_hot_degree(data, max_degree, in_degree=False, cat=True):
4646

4747

4848
class GaussianSmearing(torch.nn.Module):
49-
'''
49+
"""
5050
slightly edited version from pytorch geometric to create edge from gaussian basis
51-
'''
52-
def __init__(self, start=0.0, stop=5.0, resolution=50, width=0.05, device='cpu', **kwargs):
51+
"""
52+
def __init__(self, start=0.0, stop=5.0, resolution=50, width=0.05, device="cpu", **kwargs):
5353
super(GaussianSmearing, self).__init__()
5454
offset = torch.linspace(start, stop, resolution, device=device)
5555
# self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
@@ -93,22 +93,21 @@ def get_ranges(dataset, descriptor_label):
9393

9494
def clean_up(data_list, attr_list):
9595
if not attr_list:
96-
return data_list
96+
return
9797

98+
# check which attributes in the list are removable
99+
removable_attrs = [t for t in attr_list if t in data_list[0].to_dict()]
98100
for data in data_list:
99-
for attr in attr_list:
100-
try:
101-
delattr(data, attr)
102-
except:
103-
continue
101+
for attr in removable_attrs:
102+
delattr(data, attr)
104103

105104
def get_distances(
106105
positions: torch.Tensor,
107106
offsets: torch.Tensor,
108-
device: str = 'cpu',
107+
device: str = "cpu",
109108
mic: bool = True
110109
):
111-
'''
110+
"""
112111
Get pairwise atomic distances
113112
114113
Parameters
@@ -123,7 +122,7 @@ def get_distances(
123122
124123
mic: bool
125124
minimum image convention
126-
'''
125+
"""
127126

128127
# convert numpy array to torch tensors
129128
n_atoms = len(positions)
@@ -141,7 +140,7 @@ def get_distances(
141140
# this allows us to get the minimum self-loop distance
142141
# of an atom to itself in all other images
143142
origin_unit_cell_idx = 13
144-
# atomic_distances[:,:,origin_unit_cell_idx].fill_diagonal_(float('inf'))
143+
# atomic_distances[:,:,origin_unit_cell_idx].fill_diagonal_(float("inf"))
145144

146145
# get minimum
147146
min_atomic_distances, min_indices = torch.min(atomic_distances, dim=-1)
@@ -154,8 +153,8 @@ def get_distances(
154153
return min_atomic_distances, min_indices
155154

156155

157-
def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = 'cpu'):
158-
'''
156+
def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"):
157+
"""
159158
Get the periodic boundary condition (PBC) offsets for a unit cell
160159
161160
Parameters
@@ -166,15 +165,15 @@ def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = 'cpu'):
166165
the number of offsets for the unit cell
167166
if == 0: no PBC
168167
if == 1: 27-cell offsets (3x3x3)
169-
'''
168+
"""
170169

171170
_range = np.arange(-offset_number, offset_number+1)
172171
offsets = [list(x) for x in itertools.product(_range, _range, _range)]
173172
offsets = torch.tensor(offsets, device=device, dtype=torch.float)
174173
return offsets @ cell, offsets
175174

176175
def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop, offset_number=1):
177-
'''
176+
"""
178177
get the distance matrix
179178
TODO: need to tune this for elongated structures
180179
@@ -192,7 +191,7 @@ def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop
192191
193192
n_neighbors: int
194193
max number of neighbors to be considered
195-
'''
194+
"""
196195

197196
cells, cell_coors = get_pbc_cells(cell, offset_number, device=device)
198197
distance_matrix, min_indices = get_distances(pos, cells, device=device)
@@ -221,14 +220,14 @@ def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop
221220
return cutoff_distance_matrix, cell_offsets
222221

223222
def add_selfloop(num_nodes, edge_indices, edge_weights, cutoff_distance_matrix, self_loop=True):
224-
'''
223+
"""
225224
add self loop (i, i) to graph structure
226225
227226
Parameters
228227
----------
229228
n_nodes: int
230229
number of nodes
231-
'''
230+
"""
232231

233232
if not self_loop:
234233
return edge_indices, edge_weights, (cutoff_distance_matrix != 0).int()
@@ -240,25 +239,25 @@ def add_selfloop(num_nodes, edge_indices, edge_weights, cutoff_distance_matrix,
240239
distance_matrix_masked = (cutoff_distance_matrix.fill_diagonal_(1) != 0).int()
241240
return edge_indices, edge_weights, distance_matrix_masked
242241

243-
def load_node_representation(node_representation='onehot'):
242+
def load_node_representation(node_representation="onehot"):
244243
node_rep_path = Path(__file__).parent
245244
default_reps = {
246-
'onehot': str(node_rep_path / './node_representations/onehot.csv')
245+
"onehot": str(node_rep_path / "./node_representations/onehot.csv")
247246
}
248247

249248
rep_file_path = node_representation
250249
if node_representation in default_reps:
251250
rep_file_path = default_reps[node_representation]
252251

253-
file_type = rep_file_path.split('.')[-1]
252+
file_type = rep_file_path.split(".")[-1]
254253
loaded_rep = None
255254

256-
if file_type == 'csv':
257-
loaded_rep = np.genfromtxt(rep_file_path, delimiter=',')
255+
if file_type == "csv":
256+
loaded_rep = np.genfromtxt(rep_file_path, delimiter=",")
258257
# TODO: need to check if typecasting to integer is needed
259258
loaded_rep = loaded_rep.astype(int)
260259

261-
elif file_type == 'json':
260+
elif file_type == "json":
262261
# TODO
263262
pass
264263

@@ -286,6 +285,6 @@ def generate_edge_features(input_data, edge_steps, r, device):
286285
if isinstance(input_data, Data):
287286
input_data = [input_data]
288287

289-
normalize_edge_cutoff(input_data, 'distance', r)
288+
normalize_edge_cutoff(input_data, "distance", r)
290289
for i, data in enumerate(input_data):
291-
input_data[i].edge_attr = distance_gaussian(input_data[i].edge_descriptor['distance'])
290+
input_data[i].edge_attr = distance_gaussian(input_data[i].edge_descriptor["distance"])

matdeeplearn/preprocessor/processor.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def process_data(dataset_config):
3232
node_representation = dataset_config.get("node_representation", "onehot")
3333
additional_attributes = dataset_config.get("additional_attributes", [])
3434
verbose: bool = dataset_config.get("verbose", True)
35+
device: str = dataset_config.get("device", "cpu")
3536

3637
processor = DataProcessor(
3738
root_path=root_path,
@@ -46,6 +47,7 @@ def process_data(dataset_config):
4647
node_representation=node_representation,
4748
additional_attributes=additional_attributes,
4849
verbose=verbose,
50+
device=device
4951
)
5052
processor.process()
5153

@@ -65,6 +67,7 @@ def __init__(
6567
node_representation: str = "onehot",
6668
additional_attributes: list = [],
6769
verbose: bool = True,
70+
device: str = "cpu",
6871
) -> None:
6972
"""
7073
create a DataProcessor that processes the raw data and save into data.pt file.
@@ -77,6 +80,9 @@ def __init__(
7780
target_file_path: str
7881
a path to a CSV file containing target y values
7982
83+
pt_path: str
84+
a path to the directory to which data.pt should be saved
85+
8086
r: float
8187
cutoff radius
8288
@@ -124,12 +130,9 @@ def __init__(
124130
self.node_representation = node_representation
125131
self.additional_attributes = additional_attributes
126132
self.verbose = verbose
133+
self.device = device
127134

128135
self.disable_tqdm = logging.root.level > logging.INFO
129-
self.device = "cpu"
130-
131-
def set_device(self, device):
132-
self.device = device
133136

134137
def src_check(self):
135138
if self.target_file_path:
@@ -195,7 +198,7 @@ def get_csv_additional_attributes(self, structure_id):
195198

196199
def json_wrap(self):
197200
"""
198-
all structures are saved to a single json file
201+
all structures are saved in a single json file
199202
"""
200203
logging.info("Reading one JSON file for multiple structures.")
201204

@@ -209,7 +212,7 @@ def json_wrap(self):
209212

210213
dict_structures = []
211214
y = []
212-
y_dim = 1
215+
y_dim = len(original_structures[0]["y"]) if isinstance(original_structures[0]["y"], list) else 1
213216

214217
logging.info("Converting data to standardized form for downstream processing.")
215218
for i, s in enumerate(tqdm(original_structures, disable=self.disable_tqdm)):
@@ -232,14 +235,13 @@ def json_wrap(self):
232235

233236
dict_structures.append(d)
234237

235-
if isinstance(s["y"], str):
236-
y.append(float(s["y"]))
237-
elif isinstance(s["y"], list):
238-
_y = [float(each) for each in s["y"]]
239-
y.append(_y)
240-
y_dim = len(_y)
241-
else:
242-
y.append(s["y"])
238+
# check y types
239+
_y = s["y"]
240+
if isinstance(_y, str):
241+
_y = float(_y)
242+
elif isinstance(_y, list):
243+
_y = [float(each) for each in _y]
244+
y.append(_y)
243245

244246
y = np.array(y).reshape(-1, y_dim)
245247
return dict_structures, y
@@ -296,7 +298,7 @@ def get_data_list(self, dict_structures, y):
296298
data.cell_offsets = cell_offsets
297299

298300
data.edge_descriptor = {}
299-
# data.edge_descriptor['mask'] = cd_matrix_masked
301+
# data.edge_descriptor["mask"] = cd_matrix_masked
300302
data.edge_descriptor["distance"] = edge_weights
301303
data.distances = edge_weights
302304
data.structure_id = [[structure_id] * len(data.y)]
@@ -314,4 +316,4 @@ def get_data_list(self, dict_structures, y):
314316

315317
clean_up(data_list, ["edge_descriptor"])
316318

317-
return data_list
319+
return data_list

0 commit comments

Comments
 (0)