-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathGCN.py
More file actions
83 lines (68 loc) · 2.56 KB
/
GCN.py
File metadata and controls
83 lines (68 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory.
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
# the CONTRIBUTORS file. See the top-level LICENSE file for details.
#
# LLNL-CODE-697807.
# All rights reserved.
#
# This file is part of LBANN: Livermore Big Artificial Neural Network
# Toolkit. For details, see http://software.llnl.gov/LBANN or
# https://github.com/LBANN and https://github.com/LLNL/LBANN.
#
# SPDX-License-Identifier: (Apache-2.0)
import torch
import torch.nn as nn
from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan
from DGraph.utils.TimingReport import TimingReport
class ConvLayer(nn.Module):
    """Dense layer: an ``nn.Linear`` projection followed by an in-place ``ReLU``.

    Despite the ``conv`` attribute name, no convolution is performed; the
    linear map acts on the last dimension of the input. The attribute names
    ``conv`` and ``act`` are kept so that ``state_dict`` keys stay ``conv.*``.

    Args:
        in_channels: Size of the input's last dimension.
        out_channels: Size of the output's last dimension.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Linear(in_channels, out_channels)
        # inplace=True only overwrites the tensor freshly produced by the
        # Linear layer, so the caller's input tensor is never modified.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return self.act(self.conv(x))
class CommAwareGCN(nn.Module):
    """GNN model that uses NCCLGraphCommPlan for distributed gather-scatter.

    The forward pass runs two rounds of ``comm.gather`` -> ``ConvLayer`` ->
    ``comm.scatter``, then a final ``nn.Linear`` projecting to
    ``num_classes``. Every stage is bracketed by ``TimingReport.start`` /
    ``TimingReport.stop`` under a fixed label.

    Args:
        in_channels: Feature size of the input node features.
        hidden_dims: Feature size used by both ConvLayers.
        num_classes: Output size of the final linear layer.
        comm: Communicator; must expose ``gather(x, comm_plan=...)`` and
            ``scatter(x, comm_plan=...)``.
    """

    def __init__(self, in_channels, hidden_dims, num_classes, comm):
        super().__init__()
        # Keep the registration order (conv1, conv2, fc, comm) unchanged so
        # the state_dict key order matches existing checkpoints.
        self.conv1 = ConvLayer(in_channels, hidden_dims)
        self.conv2 = ConvLayer(hidden_dims, hidden_dims)
        self.fc = nn.Linear(hidden_dims, num_classes)
        self.comm = comm

    def forward(
        self,
        node_features: torch.Tensor,
        comm_plan: NCCLGraphCommPlan,
    ):
        """
        Args:
            node_features: Local node features (batch, num_local_nodes, features)
            comm_plan: Pre-computed NCCLGraphCommPlan for gather-scatter

        Returns:
            Output of ``self.fc``; its last dimension is ``num_classes``.
        """

        # Communication steps are lambdas so the ``self.comm`` lookup and the
        # ``comm_plan`` binding happen inside the timed region.
        def gather(t):
            return self.comm.gather(t, comm_plan=comm_plan)

        def scatter(t):
            return self.comm.scatter(t, comm_plan=comm_plan)

        # Ordered pipeline: (TimingReport label, stage callable). Order and
        # labels are significant; each stage consumes the previous output.
        pipeline = (
            ("Gather_1", gather),
            ("Conv_1", lambda t: self.conv1(t)),
            ("Scatter_1", scatter),
            ("Gather_2", gather),
            ("Conv_2", lambda t: self.conv2(t)),
            ("Scatter_2", scatter),
            ("Final_FC", lambda t: self.fc(t)),
        )

        out = node_features
        for label, stage in pipeline:
            TimingReport.start(label)
            out = stage(out)
            TimingReport.stop(label)
        return out