66from torch_geometric .utils import remove_self_loops
77from matdeeplearn .preprocessor .helpers import compute_bond_angles , triplets
88from scipy .spatial .distance import cdist
9+ from contextlib import contextmanager
910
1011'''
1112here resides the transform classes needed for data processing
1617 The data object will be transformed before every access.
1718'''
1819
20+ TRANSFORM_REGISTRY = {}
21+
22+
23+ def register_transform (transform_name ):
24+ '''Registers a transform function for bookkeeping.'''
25+ def registered_transform (transform ):
26+ TRANSFORM_REGISTRY [transform_name ] = transform
27+ return transform
28+ return registered_transform
29+
1930
2031class GetY (object ):
2132 def __init__ (self , index = 0 ):
@@ -28,6 +39,7 @@ def __call__(self, data):
2839 return data
2940
3041
42+ @register_transform ("NumNodeTransform" )
3143class NumNodeTransform (object ):
3244 '''
3345 Adds the number of nodes to the data object
@@ -38,6 +50,7 @@ def __call__(self, data):
3850 return data
3951
4052
53+ @register_transform ("LineGraphMod" )
4154class LineGraphMod (object ):
4255 '''
4356 Adds line graph attributes to the data object
@@ -47,59 +60,37 @@ def __call__(self, data):
4760 # CODE FROM PYG LINEGRAPH TRANSFORM (DIRECTED)
4861 N = data .num_nodes
4962 edge_index , edge_attr = data .edge_index , data .edge_attr
50- (row , col ), edge_attr = coalesce (edge_index , edge_attr , N , N )
51-
52- i = torch .arange (row .size (0 ), dtype = torch .long , device = row .device )
53- count = scatter_add (torch .ones_like (row ), row , dim = 0 ,
54- dim_size = data .num_nodes )
55- cumsum = torch .cat ([count .new_zeros (1 ), count .cumsum (0 )], dim = 0 )
56-
57- cols = [
58- i [cumsum [col [j ]]:cumsum [col [j ] + 1 ]]
59- for j in range (col .size (0 ))
60- ]
61- rows = [row .new_full ((c .numel (), ), j ) for j , c in enumerate (cols )]
62-
63- row , col = torch .cat (rows , dim = 0 ), torch .cat (cols , dim = 0 )
64-
65- data .edge_index_lg = torch .stack ([row , col ], dim = 0 )
66- data .x_lg = data .edge_attr
67- data .num_nodes_lg = edge_index .size (1 )
68-
69- # CUSTOM CODE FOR CALCULATING EDGE ATTRIBUTES
70- edge_attr_lg = torch .zeros (
71- (data .edge_index_lg .shape [1 ], 1 ), device = 'cuda' )
63+ _ , edge_attr = coalesce (edge_index , edge_attr , N , N )
7264
7365 # compute bond angles
7466 angles , idx_kj , idx_ji = compute_bond_angles (
7567 data .pos , data .cell_offsets , data .edge_index , data .num_nodes )
7668 triplet_pairs = torch .stack ([idx_kj , idx_ji ], dim = 0 )
7769
78- # move triplets and edges to CPU for sklearn based calculation
79- match_indices = torch .Tensor (
80- np .where (cdist (data .edge_index_lg .T .cpu (), triplet_pairs .T .cpu ()) == 0 )[
81- 0 ].reshape (- 1 , 1 )
82- ).type (torch .long )
70+ data .edge_index_lg = triplet_pairs
71+ data .x_lg = data .edge_attr
72+ data .num_nodes_lg = edge_index .size (1 )
8373
8474 # assign bond angles to edge attributes
85- edge_attr_lg [ match_indices . squeeze ( - 1 )] = angles .reshape (- 1 , 1 )
75+ data . edge_attr_lg = angles .reshape (- 1 , 1 )
8676
87- data .edge_attr_lg = edge_attr_lg
88-
8977 return data
9078
79+
80+ @register_transform ("ToFloat" )
9181class ToFloat (object ):
9282 '''
9383 Convert non-int attributes to float
9484 '''
85+
9586 def __call__ (self , data ):
9687 data .x = data .x .float ()
9788 data .x_lg = data .x_lg .float ()
98-
89+
9990 data .distances = data .distances .float ()
10091 data .pos = data .pos .float ()
10192
10293 data .edge_attr = data .edge_attr .float ()
10394 data .edge_attr_lg = data .edge_attr_lg .float ()
10495
105- return data
96+ return data
0 commit comments