def inspect_mixed_precision(model: FSDPModule):
    model.unshard()
    for param in model.parameters(recurse=False):
        assert param.dtype == torch.bfloat16
    model.reshard()
This function is highly misleading. In the provided example, the Transformer model itself has no direct parameters; all of its parameters live in submodules. As a result, the loop over model.parameters(recurse=False) iterates zero times, so the function effectively checks nothing. I would expect it to inspect some actual parameters.
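For illustration, here is a minimal sketch of the kind of check I had in mind. This is my own assumption, not the tutorial's code: it unshards every FSDP-wrapped submodule (not just the root) so that real parameters are visited, and it fails loudly if nothing was inspected.

import torch
from torch.distributed.fsdp import FSDPModule

def inspect_mixed_precision(model: FSDPModule) -> None:
    # Unshard every FSDP-wrapped module, not just the root, so the
    # unsharded bf16 parameters of all submodules are registered.
    fsdp_modules = [m for m in model.modules() if isinstance(m, FSDPModule)]
    for m in fsdp_modules:
        m.unshard()
    checked = 0
    for param in model.parameters():  # recurse=True: visits submodule params
        assert param.dtype == torch.bfloat16
        checked += 1
    # Guard against the silent no-op this issue is about.
    assert checked > 0, "no parameters were inspected"
    for m in fsdp_modules:
        m.reshard()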
To reproduce, use this standalone script:
from dataclasses import dataclass

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)
        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)
        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

    def reset_parameters(self):
        self.wq.reset_parameters()
        self.wk.reset_parameters()
        self.wv.reset_parameters()
        self.wo.reset_parameters()
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

    def reset_parameters(self):
        self.w1.reset_parameters()
        self.w2.reset_parameters()
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def reset_parameters(self):
        self.attention_norm.reset_parameters()
        self.attention.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.feed_forward.reset_parameters()
# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    def reset_parameters(self):
        self.tok_embeddings.reset_parameters()
        self.pos_embeddings.reset_parameters()
        self.norm.reset_parameters()
        self.output.reset_parameters()
# Assumes a distributed launch, e.g. `torchrun --nproc_per_node=2 repro.py`;
# fully_shard requires an initialized default process group.
# (CUDA + NCCL is the typical setup; the gloo fallback is untested here.)
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
if torch.cuda.is_available():
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = Transformer(ModelArgs())
fsdp_kwargs = {
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
}
for layer in model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)

# sharded parameters are float32
for name, param in model.named_parameters():
    print("sharded", name, param.dtype, param.device)

# unsharded parameters are bfloat16 -- but this loop prints nothing,
# because the root Transformer has no direct parameters
model.unshard()
for name, param in model.named_parameters(recurse=False):
    print("unsharded", name, param.dtype, param.device)
model.reshard()
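With two processes, the first loop should print a float32 entry for every sharded parameter, while the recurse=False loop after unshard() prints nothing at all, since the root Transformer owns no direct parameters -- which is exactly the problem described above. The sketch near the top shows one way to make the inspection meaningful.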