1+ import numpy as np
2+ import ase
3+ from ase import io
4+ import torch
15import itertools
26from pathlib import Path
37
4- import ase
5- import numpy as np
68import torch
79import torch .nn .functional as F
8- from ase import io
10+ from torch_geometric . utils import dense_to_sparse , degree , add_self_loops
911from torch_geometric .data .data import Data
10- from torch_geometric .utils import add_self_loops , degree , dense_to_sparse
11-
1212
1313def threshold_sort (all_distances , r , n_neighbors ):
14- A = all_distances .clone ().detach ()
14+ # A = all_distances.clone().detach()
15+ A = all_distances
1516
1617 # set diagonal to zero to exclude self-loop distance
17- A .fill_diagonal_ (0 )
18+ # A.fill_diagonal_(0)
1819
1920 # keep n_neighbors only
2021 N = len (A ) - n_neighbors - 1
2122 if N > 0 :
2223 _ , indices = torch .topk (A , N )
23- A .scatter_ (
24- 1 ,
25- indices ,
26- torch .zeros (len (A ), len (A ), device = all_distances .device , dtype = torch .float ),
24+ A = torch .scatter (
25+ A ,
26+ 1 , indices , torch .zeros (len (A ), len (A ),
27+ device = all_distances .device ,
28+ dtype = torch .float )
2729 )
2830
2931 A [A > r ] = 0
3032 return A
3133
32-
3334def one_hot_degree (data , max_degree , in_degree = False , cat = True ):
3435 idx , x = data .edge_index [1 if in_degree else 0 ], data .x
3536 deg = degree (idx , data .num_nodes , dtype = torch .long )
@@ -48,10 +49,7 @@ class GaussianSmearing(torch.nn.Module):
4849 """
4950 slightly edited version from pytorch geometric to create edge from gaussian basis
5051 """
51-
52- def __init__ (
53- self , start = 0.0 , stop = 5.0 , resolution = 50 , width = 0.05 , device = "cpu" , ** kwargs
54- ):
52+ def __init__ (self , start = 0.0 , stop = 5.0 , resolution = 50 , width = 0.05 , device = "cpu" , ** kwargs ):
5553 super (GaussianSmearing , self ).__init__ ()
5654 offset = torch .linspace (start , stop , resolution , device = device )
5755 # self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
@@ -62,7 +60,6 @@ def forward(self, dist):
6260 dist = dist .unsqueeze (- 1 ) - self .offset .view (1 , - 1 )
6361 return torch .exp (self .coeff * torch .pow (dist , 2 ))
6462
65-
6663def normalize_edge (dataset , descriptor_label ):
6764 mean , std , feature_min , feature_max = get_ranges (dataset , descriptor_label )
6865
@@ -71,6 +68,9 @@ def normalize_edge(dataset, descriptor_label):
7168 data .edge_descriptor [descriptor_label ] - feature_min
7269 ) / (feature_max - feature_min )
7370
71+ def normalize_edge_cutoff (dataset , descriptor_label , r ):
72+ for data in dataset :
73+ data .edge_descriptor [descriptor_label ] = data .edge_descriptor [descriptor_label ] / r
7474
7575def get_ranges (dataset , descriptor_label ):
7676 mean = 0.0
@@ -91,42 +91,39 @@ def get_ranges(dataset, descriptor_label):
9191 std = std / len (dataset )
9292 return mean , std , feature_min , feature_max
9393
94-
9594def clean_up (data_list , attr_list ):
9695 if not attr_list :
97- return data_list
98-
96+ return
97+
98+ # check which attributes in the list are removable
99+ removable_attrs = [t for t in attr_list if t in data_list [0 ].to_dict ()]
99100 for data in data_list :
100- for attr in attr_list :
101- try :
102- delattr (data , attr )
103- except AttributeError :
104- continue
105-
101+ for attr in removable_attrs :
102+ delattr (data , attr )
106103
107104def get_distances (
108105 positions : torch .Tensor ,
109106 offsets : torch .Tensor ,
110107 device : str = "cpu" ,
111- mic : bool = True ,
108+ mic : bool = True
112109):
113110 """
114111 Get pairwise atomic distances
115112
116113 Parameters
117114 positions: torch.Tensor
118115 positions of atoms in a unit cell
119-
116+
120117 offsets: torch.Tensor
121118 offsets for the unit cell
122-
119+
123120 device: str
124121 torch device type
125-
122+
126123 mic: bool
127124 minimum image convention
128125 """
129-
126+
130127 # convert numpy array to torch tensors
131128 n_atoms = len (positions )
132129 n_cells = len (offsets )
@@ -143,16 +140,14 @@ def get_distances(
143140 # this allows us to get the minimum self-loop distance
144141 # of an atom to itself in all other images
145142 origin_unit_cell_idx = 13
146- atomic_distances [:, :, origin_unit_cell_idx ].fill_diagonal_ (float ("inf" ))
143+ # atomic_distances[:,:, origin_unit_cell_idx].fill_diagonal_(float("inf"))
147144
148145 # get minimum
149146 min_atomic_distances , min_indices = torch .min (atomic_distances , dim = - 1 )
150147 expanded_min_indices = min_indices .clone ().detach ()
151148
152149 atom_rij = pos1 - pos2
153- expanded_min_indices = expanded_min_indices [..., None , None ].expand (
154- - 1 , - 1 , 1 , atom_rij .size (3 )
155- )
150+ expanded_min_indices = expanded_min_indices [..., None , None ].expand (- 1 , - 1 , 1 , atom_rij .size (3 ))
156151 atom_rij = torch .gather (atom_rij , dim = 2 , index = expanded_min_indices ).squeeze ()
157152
158153 return min_atomic_distances , min_indices
@@ -161,7 +156,7 @@ def get_distances(
161156def get_pbc_cells (cell : torch .Tensor , offset_number : int , device : str = "cpu" ):
162157 """
163158 Get the periodic boundary condition (PBC) offsets for a unit cell
164-
159+
165160 Parameters
166161 cell: torch.Tensor
167162 unit cell vectors of ase.cell.Cell
@@ -172,25 +167,22 @@ def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"):
172167 if == 1: 27-cell offsets (3x3x3)
173168 """
174169
175- _range = np .arange (- offset_number , offset_number + 1 )
170+ _range = np .arange (- offset_number , offset_number + 1 )
176171 offsets = [list (x ) for x in itertools .product (_range , _range , _range )]
177172 offsets = torch .tensor (offsets , device = device , dtype = torch .float )
178173 return offsets @ cell , offsets
179174
180-
181- def get_cutoff_distance_matrix (
182- pos , cell , r , n_neighbors , device , image_selfloop , offset_number = 1
183- ):
175+ def get_cutoff_distance_matrix (pos , cell , r , n_neighbors , device , image_selfloop , offset_number = 1 ):
184176 """
185177 get the distance matrix
186178 TODO: need to tune this for elongated structures
187179
188180 Parameters
189181 ----------
190- pos: np.ndarray
182+ pos: np.ndarray
191183 positions of atoms in a unit cell
192184 get from crystal.get_positions()
193-
185+
194186 cell: np.ndarray
195187 unit cell of a ase Atoms object
196188
@@ -206,11 +198,11 @@ def get_cutoff_distance_matrix(
206198
207199 cutoff_distance_matrix = threshold_sort (distance_matrix , r , n_neighbors )
208200
209- if image_selfloop :
210- # output of threshold sort has diagonal == 0
211- # fill in the original values
212- self_loop_diag = distance_matrix .diagonal ()
213- cutoff_distance_matrix .diagonal ().copy_ (self_loop_diag )
201+ # if image_selfloop:
202+ # # output of threshold sort has diagonal == 0
203+ # # fill in the original values
204+ # self_loop_diag = distance_matrix.diagonal()
205+ # cutoff_distance_matrix.diagonal().copy_(self_loop_diag)
214206
215207 all_cell_offsets = cell_coors [torch .flatten (min_indices )]
216208 all_cell_offsets = all_cell_offsets .view (len (pos ), - 1 , 3 )
@@ -222,15 +214,12 @@ def get_cutoff_distance_matrix(
222214 # thus initialize a zero matrix of (M+N, 3) for cell offsets
223215 n_edges = torch .count_nonzero (cutoff_distance_matrix ).item ()
224216 cell_offsets = torch .zeros (n_edges + len (pos ), 3 , dtype = torch .float )
225- # get cells for edges except for self loops
217+ # get cells for edges except for self loops
226218 cell_offsets [:n_edges , :] = all_cell_offsets [cutoff_distance_matrix != 0 ]
227219
228220 return cutoff_distance_matrix , cell_offsets
229221
230-
231- def add_selfloop (
232- num_nodes , edge_indices , edge_weights , cutoff_distance_matrix , self_loop = True
233- ):
222+ def add_selfloop (num_nodes , edge_indices , edge_weights , cutoff_distance_matrix , self_loop = True ):
234223 """
235224 add self loop (i, i) to graph structure
236225
@@ -250,15 +239,16 @@ def add_selfloop(
250239 distance_matrix_masked = (cutoff_distance_matrix .fill_diagonal_ (1 ) != 0 ).int ()
251240 return edge_indices , edge_weights , distance_matrix_masked
252241
253-
254242def load_node_representation (node_representation = "onehot" ):
255243 node_rep_path = Path (__file__ ).parent
256- default_reps = {"onehot" : str (node_rep_path / "./node_representations/onehot.csv" )}
244+ default_reps = {
245+ "onehot" : str (node_rep_path / "./node_representations/onehot.csv" )
246+ }
257247
258248 rep_file_path = node_representation
259249 if node_representation in default_reps :
260250 rep_file_path = default_reps [node_representation ]
261-
251+
262252 file_type = rep_file_path .split ("." )[- 1 ]
263253 loaded_rep = None
264254
@@ -273,33 +263,28 @@ def load_node_representation(node_representation="onehot"):
273263
274264 return loaded_rep
275265
276-
277266def generate_node_features (input_data , n_neighbors , device ):
278267 node_reps = load_node_representation ()
279268 node_reps = torch .from_numpy (node_reps ).to (device )
280269 n_elements , n_features = node_reps .shape
281-
270+
282271 if isinstance (input_data , Data ):
283- input_data .x = node_reps [input_data .z - 1 ].view (- 1 , n_features )
284- return one_hot_degree (input_data , n_neighbors + 1 )
272+ input_data .x = node_reps [input_data .z - 1 ].view (- 1 ,n_features )
273+ return one_hot_degree (input_data , n_neighbors + 1 )
285274
286275 for i , data in enumerate (input_data ):
287276 # minus 1 as the reps are 0-indexed but atomic number starts from 1
288- data .x = node_reps [data .z - 1 ].view (- 1 , n_features )
277+ data .x = node_reps [data .z - 1 ].view (- 1 ,n_features )
289278
290279 for i , data in enumerate (input_data ):
291- input_data [i ] = one_hot_degree (data , n_neighbors + 1 )
292-
280+ input_data [i ] = one_hot_degree (data , n_neighbors + 1 )
293281
294- def generate_edge_features (input_data , edge_steps , device ):
282+ def generate_edge_features (input_data , edge_steps , r , device ):
295283 distance_gaussian = GaussianSmearing (0 , 1 , edge_steps , 0.2 , device = device )
296284
297285 if isinstance (input_data , Data ):
298- input_data .edge_attr = distance_gaussian (input_data .edge_descriptor ["distance" ])
299- return
286+ input_data = [input_data ]
300287
301- normalize_edge (input_data , "distance" )
288+ normalize_edge_cutoff (input_data , "distance" , r )
302289 for i , data in enumerate (input_data ):
303- input_data [i ].edge_attr = distance_gaussian (
304- input_data [i ].edge_descriptor ["distance" ]
305- )
290+ input_data [i ].edge_attr = distance_gaussian (input_data [i ].edge_descriptor ["distance" ])
0 commit comments