-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathlayers.py
More file actions
218 lines (193 loc) · 7.68 KB
/
layers.py
File metadata and controls
218 lines (193 loc) · 7.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory.
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
# the CONTRIBUTORS file. See the top-level LICENSE file for details.
#
# LLNL-CODE-697807.
# All rights reserved.
#
# This file is part of LBANN: Livermore Big Artificial Neural Network
# Toolkit. For details, see http://software.llnl.gov/LBANN or
# https://github.com/LBANN and https://github.com/LLNL/LBANN.
#
# SPDX-License-Identifier: (Apache-2.0)
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from typing import Optional
from DGraph.Communicator import Communicator
from dist_utils import SingleProcessDummyCommunicator
# class MLPSiLuWithRecompute(nn.Module):
class MeshGraphMLP(nn.Module):
    """MLP for graph processing.

    The network is ``Linear -> activation`` for each hidden layer, followed by
    an output ``Linear`` and, unless disabled, a normalization layer built
    over the output features.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 512,
        hidden_layers: int = 1,
        activation_fn: nn.Module = nn.SiLU(),
        norm_type: Optional[str] = "LayerNorm",
    ):
        """
        Initializes a MeshGraphMLP instance.

        Args:
            input_dim (int): The dimensionality of the input features.
            output_dim (int): The dimensionality of the output features.
            hidden_dim (int, optional): The dimensionality of the hidden layers. Defaults to 512.
            hidden_layers (int, optional): The number of hidden layers. Defaults to 1.
                Values below 1 still produce a single hidden layer, because the
                first ``Linear -> activation`` pair is always created.
            activation_fn (nn.Module, optional): The activation function to use. Defaults to nn.SiLU().
                The same module instance is inserted after every hidden layer.
            norm_type (str or None, optional): Name of the ``torch.nn`` class used to
                normalize the output (constructed as ``norm(output_dim)``).
                Pass ``None`` to omit the normalization layer. Defaults to "LayerNorm".

        Raises:
            AttributeError: If ``norm_type`` is a string that does not name an
                attribute of ``torch.nn``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Resolve the normalization class up front so an unknown name fails
        # before any layer is allocated; None means "no output normalization".
        norm_layer = None if norm_type is None else getattr(nn, norm_type)

        layers = [
            nn.Linear(input_dim, hidden_dim),
            activation_fn,
        ]
        for _ in range(hidden_layers - 1):
            layers += [
                nn.Linear(hidden_dim, hidden_dim),
                activation_fn,
            ]
        layers.append(nn.Linear(hidden_dim, output_dim))
        if norm_layer is not None:
            layers.append(norm_layer(output_dim))

        self._model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the MLP

        Args:
            x: Node or edge features; the last dimension must match ``input_dim``
                as required by the first ``nn.Linear``.

        Returns:
            The transformed tensor, with ``output_dim`` features in the last dimension.
        """
        return self._model(x)
class MeshNodeBlock(nn.Module):
    """Node block for mesh processing. Used in GraphCast and MeshGraphNet.

    Edge features are reduced onto nodes through the communicator, joined
    with the node features, passed through an MLP, and the result is added
    back to the incoming node features.
    """

    def __init__(
        self,
        input_node_dim: int,
        input_edge_dim: int,
        output_node_dim: int,
        comm: Union[Communicator, SingleProcessDummyCommunicator],
        hidden_dim: int = 512,
        num_hidden_layers: int = 1,
        aggregation_type: str = "sum",
    ):
        """
        Initializes a MeshNodeBlock instance.

        Args:
            input_node_dim (int): The dimensionality of the input node features.
            input_edge_dim (int): The dimensionality of the input edge features.
            output_node_dim (int): The dimensionality of the output node features.
            comm (Communicator or SingleProcessDummyCommunicator): Object whose
                ``scatter`` method performs the edge-to-node aggregation.
            hidden_dim (int, optional): The dimensionality of the hidden layers. Defaults to 512.
            num_hidden_layers (int, optional): Forwarded to the MLP as ``hidden_layers``. Defaults to 1.
            aggregation_type (str, optional): The type of aggregation to use. Only "sum"
                is accepted. Defaults to "sum".
        """
        super().__init__()

        supported_aggregations = ("sum",)
        assert (
            aggregation_type in supported_aggregations
        ), "Only sum aggregation is supported for now."

        self.comm = comm
        self.aggregation_type = aggregation_type

        # The MLP consumes node features concatenated with aggregated edge features.
        mlp_input_dim = input_node_dim + input_edge_dim
        self.mesh_mlp = MeshGraphMLP(
            input_dim=mlp_input_dim,
            output_dim=output_node_dim,
            hidden_dim=hidden_dim,
            hidden_layers=num_hidden_layers,
        )

    def forward(
        self,
        node_features: torch.Tensor,
        edge_features: torch.Tensor,
        src_indices: torch.Tensor,
        rank_mapping: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the node block

        Args:
            node_features: The node features; the first dimension is taken as the
                number of local nodes.
            edge_features: The edge features handed to ``comm.scatter``.
            src_indices: The indices handed to ``comm.scatter`` alongside the edge features.
            rank_mapping: Optional tensor forwarded unchanged to ``comm.scatter``.

        Returns:
            The MLP output added to ``node_features`` (residual update).
        """
        local_node_count = node_features.shape[0]

        # Reduce edge features onto the local nodes (presumably a per-node sum;
        # the actual reduction is defined by the communicator's scatter).
        # TODO: This can be optimized by a fused gather-scatter operation - S.Z
        per_node_edge_features = self.comm.scatter(
            edge_features, src_indices, rank_mapping, local_node_count
        )

        # Join node and aggregated edge features along the feature dimension.
        mlp_input = torch.cat((node_features, per_node_edge_features), dim=-1)

        # Residual connection around the MLP.
        node_update = self.mesh_mlp(mlp_input)
        return node_update + node_features
class MeshEdgeBlock(nn.Module):
    """Edge block for mesh processing. Used in GraphCast and MeshGraphNet.

    Source and destination node features are gathered per edge through the
    communicator, joined with the edge features, passed through an MLP, and
    the result is added back to the incoming edge features.
    """

    def __init__(
        self,
        input_src_node_dim: int,
        input_dst_node_dim: int,
        input_edge_dim: int,
        output_edge_dim: int,
        comm: Union[Communicator, SingleProcessDummyCommunicator],
        hidden_dim: int = 512,
        num_hidden_layers: int = 1,
        aggregation_type: str = "sum",
    ):
        """
        Initializes a MeshEdgeBlock instance.

        Args:
            input_src_node_dim (int): The dimensionality of the source node features.
            input_dst_node_dim (int): The dimensionality of the destination node features.
            input_edge_dim (int): The dimensionality of the input edge features.
            output_edge_dim (int): The dimensionality of the output edge features.
            comm (Communicator or SingleProcessDummyCommunicator): Object whose
                ``gather`` method fetches node features for each edge.
            hidden_dim (int, optional): The dimensionality of the hidden layers. Defaults to 512.
            num_hidden_layers (int, optional): Forwarded to the MLP as ``hidden_layers``. Defaults to 1.
            aggregation_type (str, optional): The type of aggregation to use. Only "sum"
                is accepted; the value is stored but not otherwise read in this class.
                Defaults to "sum".
        """
        # TODO: Add concat trick for edge features - S.Z
        super().__init__()

        supported_aggregations = ("sum",)
        assert (
            aggregation_type in supported_aggregations
        ), "Only sum aggregation is supported for now."

        self.comm = comm
        self.aggregation_type = aggregation_type

        # The MLP consumes [src node | dst node | edge] features per edge.
        mlp_input_dim = input_src_node_dim + input_dst_node_dim + input_edge_dim
        self.mesh_mlp = MeshGraphMLP(
            input_dim=mlp_input_dim,
            output_dim=output_edge_dim,
            hidden_dim=hidden_dim,
            hidden_layers=num_hidden_layers,
        )

    def forward(
        self,
        src_node_features: torch.Tensor,
        dst_node_features: torch.Tensor,
        edge_features: torch.Tensor,
        src_indices: torch.Tensor,
        dst_indices: torch.Tensor,
        src_rank_mapping: Optional[torch.Tensor] = None,
        dst_rank_mapping: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the edge block

        Args:
            src_node_features: The source node features passed to ``comm.gather``.
            dst_node_features: The destination node features passed to ``comm.gather``.
            edge_features: The edge features.
            src_indices: The source indices used to gather ``src_node_features``.
            dst_indices: The destination indices used to gather ``dst_node_features``.
            src_rank_mapping: Optional tensor forwarded to the source gather.
            dst_rank_mapping: Optional tensor forwarded to the destination gather.

        Returns:
            The MLP output added to ``edge_features`` (residual update).
        """
        # Fetch per-edge endpoint features; source first, then destination.
        per_edge_src = self.comm.gather(
            src_node_features, src_indices, src_rank_mapping
        )
        per_edge_dst = self.comm.gather(
            dst_node_features, dst_indices, dst_rank_mapping
        )

        # Join endpoint and edge features along the feature dimension.
        mlp_input = torch.cat((per_edge_src, per_edge_dst, edge_features), dim=-1)

        # Residual connection around the MLP.
        edge_update = self.mesh_mlp(mlp_input)
        return edge_update + edge_features