Skip to content

Commit f896062

Browse files
committed
Fix log_deg shape
1 parent 422c182 commit f896062

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/graphnet/models/components/layers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,11 @@ def forward(self, data: Data) -> Data:
880880
"""Forward pass."""
881881
x = data.x
882882
num_nodes = data.num_nodes
883-
log_deg = torch.log10(degree(data.edge_index[0]) + 1)
883+
log_deg = torch.log10(
884+
degree(data.edge_index[0], num_nodes=num_nodes, dtype=data.x.dtype)
885+
+ 1
886+
)
887+
log_deg = log_deg.view(data.num_nodes, 1)
884888

885889
x_attn_residual = x # for first residual connection
886890
e_values_in = data.get("edge_attr", None)

0 commit comments

Comments
 (0)