We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 422c182 commit f896062Copy full SHA for f896062
1 file changed
src/graphnet/models/components/layers.py
@@ -880,7 +880,11 @@ def forward(self, data: Data) -> Data:
880
"""Forward pass."""
881
x = data.x
882
num_nodes = data.num_nodes
883
- log_deg = torch.log10(degree(data.edge_index[0]) + 1)
+ log_deg = torch.log10(
884
+ degree(data.edge_index[0], num_nodes=num_nodes, dtype=data.x.dtype)
885
+ + 1
886
+ )
887
+ log_deg = log_deg.view(data.num_nodes, 1)
888
889
x_attn_residual = x # for first residual connection
890
e_values_in = data.get("edge_attr", None)
0 commit comments