|
14 | 14 | import torch |
15 | 15 | import torch.nn as nn |
16 | 16 | import torch.distributed as dist |
| 17 | +from DGraph.utils.TimingReport import TimingReport |
17 | 18 |
|
18 | 19 |
|
19 | 20 | class ConvLayer(nn.Module): |
@@ -54,24 +55,41 @@ def forward( |
54 | 55 | num_local_nodes = node_features.size(1) |
55 | 56 | _src_indices = edge_index[:, 0, :] |
56 | 57 | _dst_indices = edge_index[:, 1, :] |
| 58 | + TimingReport.start("pre-processing") |
57 | 59 | _src_rank_mappings = torch.cat( |
58 | 60 | [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 |
59 | 61 | ) |
60 | 62 | _dst_rank_mappings = torch.cat( |
61 | 63 | [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 |
62 | 64 | ) |
| 65 | + TimingReport.stop("pre-processing") |
| 66 | + TimingReport.start("Gather_1") |
63 | 67 | x = self.comm.gather( |
64 | 68 | node_features, _dst_indices, _dst_rank_mappings, cache=gather_cache |
65 | 69 | ) |
| 70 | + TimingReport.stop("Gather_1") |
| 71 | + TimingReport.start("Conv_1") |
66 | 72 | x = self.conv1(x) |
| 73 | + TimingReport.stop("Conv_1") |
| 74 | + TimingReport.start("Scatter_1") |
67 | 75 | x = self.comm.scatter( |
68 | 76 | x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache |
69 | 77 | ) |
| 78 | + TimingReport.stop("Scatter_1") |
| 79 | + TimingReport.start("Gather_2") |
70 | 80 | x = self.comm.gather(x, _dst_indices, _dst_rank_mappings, cache=gather_cache) |
| 81 | + TimingReport.stop("Gather_2") |
| 82 | + TimingReport.start("Conv_2") |
71 | 83 | x = self.conv2(x) |
| 84 | + TimingReport.stop("Conv_2") |
| 85 | + TimingReport.start("Scatter_2") |
72 | 86 | x = self.comm.scatter( |
73 | 87 | x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache |
74 | 88 | ) |
| 89 | + TimingReport.stop("Scatter_2") |
| 90 | + TimingReport.start("Final_FC") |
75 | 91 | x = self.fc(x) |
| 92 | + TimingReport.stop("Final_FC") |
| 93 | + |
76 | 94 | # x = self.softmax(x) |
77 | 95 | return x |