@@ -34,7 +34,7 @@ def __init__(
3434 )
3535
3636 @final
37- def forward (self , x : torch .tensor ) -> Tuple [ Data , List [ str ]] :
37+ def forward (self , x : torch .tensor ) -> torch . tensor :
3838 """Construct nodes from raw node features.
3939
4040 Args:
@@ -45,9 +45,9 @@ def forward(self, x: torch.tensor) -> Tuple[Data, List[str]]:
4545 Returns:
4646 graph: a graph without edges
4747 """
48- graph = self ._construct_nodes (x = x )
48+ data = self ._construct_nodes (x = x )
4949
50- return graph
50+ return data
5151
5252 @property
5353 def _output_feature_names (self ) -> List [str ]:
@@ -110,7 +110,7 @@ def _define_output_feature_names(
110110 """
111111
112112 @abstractmethod
113- def _construct_nodes (self , x : torch .tensor ) -> Data :
113+ def _construct_nodes (self , x : torch .tensor ) -> torch . tensor :
114114 """Construct nodes from raw node features ´x´.
115115
116116 Args:
@@ -133,8 +133,8 @@ def _define_output_feature_names(
133133 ) -> List [str ]:
134134 return input_feature_names
135135
136- def _construct_nodes (self , x : torch .Tensor ) -> Data :
137- return Data ( x = x )
136+ def _construct_nodes (self , x : torch .Tensor ) -> torch . Tensor :
137+ return x
138138
139139
140140class PercentileClusters (NodeDefinition ):
@@ -195,7 +195,7 @@ def _get_indices_and_feature_names(
195195 new_feature_names .append ("counts" )
196196 return cluster_idx , summ_idx , new_feature_names
197197
198- def _construct_nodes (self , x : torch .Tensor ) -> Data :
198+ def _construct_nodes (self , x : torch .Tensor ) -> torch . Tensor :
199199 # Cast to Numpy
200200 x = x .numpy ()
201201 # Construct clusters with percentile-summarized features
@@ -219,7 +219,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
219219 ) # noqa
220220 raise AttributeError
221221
222- return Data ( x = torch .tensor (array ) )
222+ return torch .tensor (array )
223223
224224
225225class NodeAsDOMTimeSeries (NodeDefinition ):
@@ -270,7 +270,7 @@ def _define_output_feature_names(
270270 ) -> List [str ]:
271271 return input_feature_names + ["new_node_col" ]
272272
273- def _construct_nodes (self , x : torch .Tensor ) -> Data :
273+ def _construct_nodes (self , x : torch .Tensor ) -> torch . Tensor :
274274 """Construct nodes from raw node features ´x´."""
275275 # Cast to Numpy
276276 x = x .numpy ()
@@ -309,7 +309,7 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
309309 new_node_col [0 ] = 1
310310 x = np .column_stack ([x , new_node_col ])
311311
312- return Data ( x = torch .tensor (x ) )
312+ return torch .tensor (x )
313313
314314
315315class IceMixNodes (NodeDefinition ):
@@ -440,7 +440,7 @@ def _pulse_sampler(
440440
441441 return ids
442442
443- def _construct_nodes (self , x : torch .Tensor ) -> Data :
443+ def _construct_nodes (self , x : torch .Tensor ) -> torch . Tensor :
444444
445445 event_length = x .shape [0 ]
446446 if self .hlc_name is not None :
@@ -468,4 +468,4 @@ def _construct_nodes(self, x: torch.Tensor) -> Data:
468468 for idx , feature in enumerate (non_ice_features ):
469469 graph [:event_length , idx ] = x [ids , self .feature_indexes [feature ]]
470470
471- return Data ( x = graph )
471+ return graph
0 commit comments