Skip to content

Commit edd54f4

Browse files
authored
Merge pull request #7 from Fung-Lab/bugfix/processing
bugfix: processor device, in-place ops and more
2 parents 0467914 + 35ea15f commit edd54f4

5 files changed

Lines changed: 102 additions & 95 deletions

File tree

data/test_data/processed/data.pt

2.1 MB
Binary file not shown.

matdeeplearn/common/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def get_dataloader(
103103
batch_size: int,
104104
num_workers: int = 0,
105105
sampler=None,
106+
shuffle=True
106107
):
107108
"""
108109
Returns a single dataloader for a given dataset
@@ -124,7 +125,7 @@ def get_dataloader(
124125
loader = DataLoader(
125126
dataset,
126127
batch_size=batch_size,
127-
shuffle=(sampler is None),
128+
shuffle=shuffle,
128129
num_workers=num_workers,
129130
sampler=sampler,
130131
)
Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,34 @@
1-
import os
1+
import torch, os
22

3-
import torch
43
from torch_geometric.data import InMemoryDataset
54

6-
75
class StructureDataset(InMemoryDataset):
86
def __init__(
97
self,
10-
root,
11-
processed_data_path,
12-
transform=None,
13-
pre_transform=None,
8+
root,
9+
processed_data_path,
10+
transform=None,
11+
pre_transform=None,
1412
pre_filter=None,
13+
device=None
1514
):
1615
self.root = root
1716
self.processed_data_path = processed_data_path
18-
super(StructureDataset, self).__init__(
19-
root, transform, pre_transform, pre_filter
20-
)
21-
self.data, self.slices = torch.load(self.processed_paths[0])
22-
17+
super(StructureDataset, self).__init__(root, transform, pre_transform, pre_filter)
18+
19+
if not torch.cuda.is_available() or device == "cpu":
20+
self.data, self.slices = torch.load(
21+
self.processed_paths[0],
22+
map_location=torch.device('cpu')
23+
)
24+
else:
25+
self.data, self.slices = torch.load(self.processed_paths[0])
26+
27+
2328
@property
2429
def raw_file_names(self):
2530
"""
26-
The name of the files in the self.raw_dir folder
31+
The name of the files in the self.raw_dir folder
2732
that must be present in order to skip downloading.
2833
"""
2934
return []
@@ -41,11 +46,10 @@ def processed_dir(self):
4146
@property
4247
def processed_file_names(self):
4348
"""
44-
The name of the files in the self.processed_dir
49+
The name of the files in the self.processed_dir
4550
folder that must be present in order to skip processing.
4651
"""
4752
return ["data.pt"]
4853

49-
5054
class LargeStructureDataset(InMemoryDataset):
5155
pass

matdeeplearn/preprocessor/helpers.py

Lines changed: 55 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,36 @@
1+
import numpy as np
2+
import ase
3+
from ase import io
4+
import torch
15
import itertools
26
from pathlib import Path
37

4-
import ase
5-
import numpy as np
68
import torch
79
import torch.nn.functional as F
8-
from ase import io
10+
from torch_geometric.utils import dense_to_sparse, degree, add_self_loops
911
from torch_geometric.data.data import Data
10-
from torch_geometric.utils import add_self_loops, degree, dense_to_sparse
11-
1212

1313
def threshold_sort(all_distances, r, n_neighbors):
14-
A = all_distances.clone().detach()
14+
# A = all_distances.clone().detach()
15+
A = all_distances
1516

1617
# set diagonal to zero to exclude self-loop distance
17-
A.fill_diagonal_(0)
18+
# A.fill_diagonal_(0)
1819

1920
# keep n_neighbors only
2021
N = len(A) - n_neighbors - 1
2122
if N > 0:
2223
_, indices = torch.topk(A, N)
23-
A.scatter_(
24-
1,
25-
indices,
26-
torch.zeros(len(A), len(A), device=all_distances.device, dtype=torch.float),
24+
A = torch.scatter(
25+
A,
26+
1, indices, torch.zeros(len(A), len(A),
27+
device=all_distances.device,
28+
dtype=torch.float)
2729
)
2830

2931
A[A > r] = 0
3032
return A
3133

32-
3334
def one_hot_degree(data, max_degree, in_degree=False, cat=True):
3435
idx, x = data.edge_index[1 if in_degree else 0], data.x
3536
deg = degree(idx, data.num_nodes, dtype=torch.long)
@@ -48,10 +49,7 @@ class GaussianSmearing(torch.nn.Module):
4849
"""
4950
Slightly edited version of the PyTorch Geometric implementation, used to create edge features from a Gaussian basis.
5051
"""
51-
52-
def __init__(
53-
self, start=0.0, stop=5.0, resolution=50, width=0.05, device="cpu", **kwargs
54-
):
52+
def __init__(self, start=0.0, stop=5.0, resolution=50, width=0.05, device="cpu", **kwargs):
5553
super(GaussianSmearing, self).__init__()
5654
offset = torch.linspace(start, stop, resolution, device=device)
5755
# self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
@@ -62,7 +60,6 @@ def forward(self, dist):
6260
dist = dist.unsqueeze(-1) - self.offset.view(1, -1)
6361
return torch.exp(self.coeff * torch.pow(dist, 2))
6462

65-
6663
def normalize_edge(dataset, descriptor_label):
6764
mean, std, feature_min, feature_max = get_ranges(dataset, descriptor_label)
6865

@@ -71,6 +68,9 @@ def normalize_edge(dataset, descriptor_label):
7168
data.edge_descriptor[descriptor_label] - feature_min
7269
) / (feature_max - feature_min)
7370

71+
def normalize_edge_cutoff(dataset, descriptor_label, r):
72+
for data in dataset:
73+
data.edge_descriptor[descriptor_label] = data.edge_descriptor[descriptor_label] / r
7474

7575
def get_ranges(dataset, descriptor_label):
7676
mean = 0.0
@@ -91,42 +91,39 @@ def get_ranges(dataset, descriptor_label):
9191
std = std / len(dataset)
9292
return mean, std, feature_min, feature_max
9393

94-
9594
def clean_up(data_list, attr_list):
9695
if not attr_list:
97-
return data_list
98-
96+
return
97+
98+
# check which attributes in the list are removable
99+
removable_attrs = [t for t in attr_list if t in data_list[0].to_dict()]
99100
for data in data_list:
100-
for attr in attr_list:
101-
try:
102-
delattr(data, attr)
103-
except AttributeError:
104-
continue
105-
101+
for attr in removable_attrs:
102+
delattr(data, attr)
106103

107104
def get_distances(
108105
positions: torch.Tensor,
109106
offsets: torch.Tensor,
110107
device: str = "cpu",
111-
mic: bool = True,
108+
mic: bool = True
112109
):
113110
"""
114111
Get pairwise atomic distances
115112
116113
Parameters
117114
positions: torch.Tensor
118115
positions of atoms in a unit cell
119-
116+
120117
offsets: torch.Tensor
121118
offsets for the unit cell
122-
119+
123120
device: str
124121
torch device type
125-
122+
126123
mic: bool
127124
minimum image convention
128125
"""
129-
126+
130127
# convert numpy array to torch tensors
131128
n_atoms = len(positions)
132129
n_cells = len(offsets)
@@ -143,16 +140,14 @@ def get_distances(
143140
# this allows us to get the minimum self-loop distance
144141
# of an atom to itself in all other images
145142
origin_unit_cell_idx = 13
146-
atomic_distances[:, :, origin_unit_cell_idx].fill_diagonal_(float("inf"))
143+
# atomic_distances[:,:,origin_unit_cell_idx].fill_diagonal_(float("inf"))
147144

148145
# get minimum
149146
min_atomic_distances, min_indices = torch.min(atomic_distances, dim=-1)
150147
expanded_min_indices = min_indices.clone().detach()
151148

152149
atom_rij = pos1 - pos2
153-
expanded_min_indices = expanded_min_indices[..., None, None].expand(
154-
-1, -1, 1, atom_rij.size(3)
155-
)
150+
expanded_min_indices = expanded_min_indices[..., None, None].expand(-1, -1, 1, atom_rij.size(3))
156151
atom_rij = torch.gather(atom_rij, dim=2, index=expanded_min_indices).squeeze()
157152

158153
return min_atomic_distances, min_indices
@@ -161,7 +156,7 @@ def get_distances(
161156
def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"):
162157
"""
163158
Get the periodic boundary condition (PBC) offsets for a unit cell
164-
159+
165160
Parameters
166161
cell: torch.Tensor
167162
unit cell vectors of ase.cell.Cell
@@ -172,25 +167,22 @@ def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"):
172167
if == 1: 27-cell offsets (3x3x3)
173168
"""
174169

175-
_range = np.arange(-offset_number, offset_number + 1)
170+
_range = np.arange(-offset_number, offset_number+1)
176171
offsets = [list(x) for x in itertools.product(_range, _range, _range)]
177172
offsets = torch.tensor(offsets, device=device, dtype=torch.float)
178173
return offsets @ cell, offsets
179174

180-
181-
def get_cutoff_distance_matrix(
182-
pos, cell, r, n_neighbors, device, image_selfloop, offset_number=1
183-
):
175+
def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop, offset_number=1):
184176
"""
185177
get the distance matrix
186178
TODO: need to tune this for elongated structures
187179
188180
Parameters
189181
----------
190-
pos: np.ndarray
182+
pos: np.ndarray
191183
positions of atoms in a unit cell
192184
get from crystal.get_positions()
193-
185+
194186
cell: np.ndarray
195187
unit cell of an ase Atoms object
196188
@@ -206,11 +198,11 @@ def get_cutoff_distance_matrix(
206198

207199
cutoff_distance_matrix = threshold_sort(distance_matrix, r, n_neighbors)
208200

209-
if image_selfloop:
210-
# output of threshold sort has diagonal == 0
211-
# fill in the original values
212-
self_loop_diag = distance_matrix.diagonal()
213-
cutoff_distance_matrix.diagonal().copy_(self_loop_diag)
201+
# if image_selfloop:
202+
# # output of threshold sort has diagonal == 0
203+
# # fill in the original values
204+
# self_loop_diag = distance_matrix.diagonal()
205+
# cutoff_distance_matrix.diagonal().copy_(self_loop_diag)
214206

215207
all_cell_offsets = cell_coors[torch.flatten(min_indices)]
216208
all_cell_offsets = all_cell_offsets.view(len(pos), -1, 3)
@@ -222,15 +214,12 @@ def get_cutoff_distance_matrix(
222214
# thus initialize a zero matrix of (M+N, 3) for cell offsets
223215
n_edges = torch.count_nonzero(cutoff_distance_matrix).item()
224216
cell_offsets = torch.zeros(n_edges + len(pos), 3, dtype=torch.float)
225-
# get cells for edges except for self loops
217+
# get cells for edges except for self loops
226218
cell_offsets[:n_edges, :] = all_cell_offsets[cutoff_distance_matrix != 0]
227219

228220
return cutoff_distance_matrix, cell_offsets
229221

230-
231-
def add_selfloop(
232-
num_nodes, edge_indices, edge_weights, cutoff_distance_matrix, self_loop=True
233-
):
222+
def add_selfloop(num_nodes, edge_indices, edge_weights, cutoff_distance_matrix, self_loop=True):
234223
"""
235224
add self loop (i, i) to graph structure
236225
@@ -250,15 +239,16 @@ def add_selfloop(
250239
distance_matrix_masked = (cutoff_distance_matrix.fill_diagonal_(1) != 0).int()
251240
return edge_indices, edge_weights, distance_matrix_masked
252241

253-
254242
def load_node_representation(node_representation="onehot"):
255243
node_rep_path = Path(__file__).parent
256-
default_reps = {"onehot": str(node_rep_path / "./node_representations/onehot.csv")}
244+
default_reps = {
245+
"onehot": str(node_rep_path / "./node_representations/onehot.csv")
246+
}
257247

258248
rep_file_path = node_representation
259249
if node_representation in default_reps:
260250
rep_file_path = default_reps[node_representation]
261-
251+
262252
file_type = rep_file_path.split(".")[-1]
263253
loaded_rep = None
264254

@@ -273,33 +263,28 @@ def load_node_representation(node_representation="onehot"):
273263

274264
return loaded_rep
275265

276-
277266
def generate_node_features(input_data, n_neighbors, device):
278267
node_reps = load_node_representation()
279268
node_reps = torch.from_numpy(node_reps).to(device)
280269
n_elements, n_features = node_reps.shape
281-
270+
282271
if isinstance(input_data, Data):
283-
input_data.x = node_reps[input_data.z - 1].view(-1, n_features)
284-
return one_hot_degree(input_data, n_neighbors + 1)
272+
input_data.x = node_reps[input_data.z-1].view(-1,n_features)
273+
return one_hot_degree(input_data, n_neighbors+1)
285274

286275
for i, data in enumerate(input_data):
287276
# subtract 1 because the representations are 0-indexed but atomic numbers start from 1
288-
data.x = node_reps[data.z - 1].view(-1, n_features)
277+
data.x = node_reps[data.z-1].view(-1,n_features)
289278

290279
for i, data in enumerate(input_data):
291-
input_data[i] = one_hot_degree(data, n_neighbors + 1)
292-
280+
input_data[i] = one_hot_degree(data, n_neighbors+1)
293281

294-
def generate_edge_features(input_data, edge_steps, device):
282+
def generate_edge_features(input_data, edge_steps, r, device):
295283
distance_gaussian = GaussianSmearing(0, 1, edge_steps, 0.2, device=device)
296284

297285
if isinstance(input_data, Data):
298-
input_data.edge_attr = distance_gaussian(input_data.edge_descriptor["distance"])
299-
return
286+
input_data = [input_data]
300287

301-
normalize_edge(input_data, "distance")
288+
normalize_edge_cutoff(input_data, "distance", r)
302289
for i, data in enumerate(input_data):
303-
input_data[i].edge_attr = distance_gaussian(
304-
input_data[i].edge_descriptor["distance"]
305-
)
290+
input_data[i].edge_attr = distance_gaussian(input_data[i].edge_descriptor["distance"])

0 commit comments

Comments
 (0)