@@ -46,10 +46,10 @@ def one_hot_degree(data, max_degree, in_degree=False, cat=True):
4646
4747
4848class GaussianSmearing (torch .nn .Module ):
49- '''
49+ """
5050 slightly edited version from pytorch geometric to create edge from gaussian basis
51- '''
52- def __init__ (self , start = 0.0 , stop = 5.0 , resolution = 50 , width = 0.05 , device = ' cpu' , ** kwargs ):
51+ """
52+ def __init__ (self , start = 0.0 , stop = 5.0 , resolution = 50 , width = 0.05 , device = " cpu" , ** kwargs ):
5353 super (GaussianSmearing , self ).__init__ ()
5454 offset = torch .linspace (start , stop , resolution , device = device )
5555 # self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
@@ -93,22 +93,21 @@ def get_ranges(dataset, descriptor_label):
9393
def clean_up(data_list, attr_list):
    """
    Remove the given attributes, in place, from every object in a list

    Parameters
    ----------
        data_list: list
            objects to strip attributes from; each object is mutated in place

        attr_list: list of str
            names of the attributes to delete; a name that is absent on a
            particular object is skipped for that object

    Returns
    -------
        None
    """
    # nothing to do for an empty dataset or an empty attribute list; this
    # also avoids indexing into an empty data_list
    if not data_list or not attr_list:
        return

    for data in data_list:
        for attr in attr_list:
            # decide per object rather than from data_list[0] only: objects in
            # the list are not guaranteed to carry the same attributes
            try:
                delattr(data, attr)
            except (AttributeError, KeyError):
                # attribute not present on this object; only "missing"
                # errors are ignored, anything else propagates
                continue
104103
105104def get_distances (
106105 positions : torch .Tensor ,
107106 offsets : torch .Tensor ,
108- device : str = ' cpu' ,
107+ device : str = " cpu" ,
109108 mic : bool = True
110109):
111- '''
110+ """
112111 Get pairwise atomic distances
113112
114113 Parameters
@@ -123,7 +122,7 @@ def get_distances(
123122
124123 mic: bool
125124 minimum image convention
126- '''
125+ """
127126
128127 # convert numpy array to torch tensors
129128 n_atoms = len (positions )
@@ -141,7 +140,7 @@ def get_distances(
141140 # this allows us to get the minimum self-loop distance
142141 # of an atom to itself in all other images
143142 origin_unit_cell_idx = 13
144- # atomic_distances[:,:,origin_unit_cell_idx].fill_diagonal_(float(' inf' ))
143+ # atomic_distances[:,:,origin_unit_cell_idx].fill_diagonal_(float(" inf" ))
145144
146145 # get minimum
147146 min_atomic_distances , min_indices = torch .min (atomic_distances , dim = - 1 )
@@ -154,8 +153,8 @@ def get_distances(
154153 return min_atomic_distances , min_indices
155154
156155
157- def get_pbc_cells (cell : torch .Tensor , offset_number : int , device : str = ' cpu' ):
158- '''
156+ def get_pbc_cells (cell : torch .Tensor , offset_number : int , device : str = " cpu" ):
157+ """
159158 Get the periodic boundary condition (PBC) offsets for a unit cell
160159
161160 Parameters
@@ -166,15 +165,15 @@ def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = 'cpu'):
166165 the number of offsets for the unit cell
167166 if == 0: no PBC
168167 if == 1: 27-cell offsets (3x3x3)
169- '''
168+ """
170169
171170 _range = np .arange (- offset_number , offset_number + 1 )
172171 offsets = [list (x ) for x in itertools .product (_range , _range , _range )]
173172 offsets = torch .tensor (offsets , device = device , dtype = torch .float )
174173 return offsets @ cell , offsets
175174
176175def get_cutoff_distance_matrix (pos , cell , r , n_neighbors , device , image_selfloop , offset_number = 1 ):
177- '''
176+ """
178177 get the distance matrix
179178 TODO: need to tune this for elongated structures
180179
@@ -192,7 +191,7 @@ def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop
192191
193192 n_neighbors: int
194193 max number of neighbors to be considered
195- '''
194+ """
196195
197196 cells , cell_coors = get_pbc_cells (cell , offset_number , device = device )
198197 distance_matrix , min_indices = get_distances (pos , cells , device = device )
@@ -221,14 +220,14 @@ def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop
221220 return cutoff_distance_matrix , cell_offsets
222221
223222def add_selfloop (num_nodes , edge_indices , edge_weights , cutoff_distance_matrix , self_loop = True ):
224- '''
223+ """
225224 add self loop (i, i) to graph structure
226225
227226 Parameters
228227 ----------
229228 n_nodes: int
230229 number of nodes
231- '''
230+ """
232231
233232 if not self_loop :
234233 return edge_indices , edge_weights , (cutoff_distance_matrix != 0 ).int ()
@@ -240,25 +239,25 @@ def add_selfloop(num_nodes, edge_indices, edge_weights, cutoff_distance_matrix,
240239 distance_matrix_masked = (cutoff_distance_matrix .fill_diagonal_ (1 ) != 0 ).int ()
241240 return edge_indices , edge_weights , distance_matrix_masked
242241
243- def load_node_representation (node_representation = ' onehot' ):
242+ def load_node_representation (node_representation = " onehot" ):
244243 node_rep_path = Path (__file__ ).parent
245244 default_reps = {
246- ' onehot' : str (node_rep_path / ' ./node_representations/onehot.csv' )
245+ " onehot" : str (node_rep_path / " ./node_representations/onehot.csv" )
247246 }
248247
249248 rep_file_path = node_representation
250249 if node_representation in default_reps :
251250 rep_file_path = default_reps [node_representation ]
252251
253- file_type = rep_file_path .split ('.' )[- 1 ]
252+ file_type = rep_file_path .split ("." )[- 1 ]
254253 loaded_rep = None
255254
256- if file_type == ' csv' :
257- loaded_rep = np .genfromtxt (rep_file_path , delimiter = ',' )
255+ if file_type == " csv" :
256+ loaded_rep = np .genfromtxt (rep_file_path , delimiter = "," )
258257 # TODO: need to check if typecasting to integer is needed
259258 loaded_rep = loaded_rep .astype (int )
260259
261- elif file_type == ' json' :
260+ elif file_type == " json" :
262261 # TODO
263262 pass
264263
@@ -286,6 +285,6 @@ def generate_edge_features(input_data, edge_steps, r, device):
286285 if isinstance (input_data , Data ):
287286 input_data = [input_data ]
288287
289- normalize_edge_cutoff (input_data , ' distance' , r )
288+ normalize_edge_cutoff (input_data , " distance" , r )
290289 for i , data in enumerate (input_data ):
291- input_data [i ].edge_attr = distance_gaussian (input_data [i ].edge_descriptor [' distance' ])
290+ input_data [i ].edge_attr = distance_gaussian (input_data [i ].edge_descriptor [" distance" ])
0 commit comments