55and can be passed to dataloaders during training and deployment.
66"""
77
8- from typing import List , Optional , Dict , Union , Tuple , Any , Callable
8+ from typing import List , Optional , Dict , Union , Any , Callable
99import torch
1010from numpy .random import Generator
1111import numpy as np
@@ -120,23 +120,6 @@ def _set_output_feature_names(
120120 self ._node_definition .set_output_feature_names (input_feature_names )
121121 return self ._node_definition ._output_feature_names
122122
123- def _create_data (
124- self , input_features : torch .Tensor
125- ) -> Tuple [Data , List [str ]]:
126- # Create graph & get new node feature names
127- data = self ._node_definition (input_features )
128- if self ._sort_by is not None :
129- data .x = data .x [data .x [:, self ._sort_by ].sort ()[1 ]]
130-
131- # Enforce dtype
132- data .x = data .x .type (self .dtype )
133-
134- # Assign edges
135- if self ._edge_definition is not None :
136- data = self ._edge_definition (data )
137-
138- return data
139-
140123 def forward ( # type: ignore
141124 self ,
142125 input_features : np .ndarray ,
@@ -178,6 +161,18 @@ def forward( # type: ignore
178161 data_path = data_path ,
179162 )
180163
164+ # Create graph & get new node feature names
165+ data = self ._node_definition (data .x )
166+ if self ._sort_by is not None :
167+ data .x = data .x [data .x [:, self ._sort_by ].sort ()[1 ]]
168+
169+ # Enforce dtype
170+ data .x = data .x .type (self .dtype )
171+
172+ # Assign edges
173+ if self ._edge_definition is not None :
174+ data = self ._edge_definition (data )
175+
181176 if self ._add_static_features :
182177 data = self ._add_features_individually (
183178 data ,
@@ -186,17 +181,6 @@ def forward( # type: ignore
186181
187182 return data
188183
189- def _forward_end (
190- self ,
191- data : Data ,
192- data_feature_names : List [str ],
193- ) -> Data :
194- """Add processing steps at the end of the forward pass."""
195- # Add original features as attributes
196- if self ._add_static_features :
197- data = self ._add_features_individually (data , data_feature_names )
198- return data
199-
200184 def _add_features_individually (
201185 self ,
202186 data : Data ,
0 commit comments