Commit 3434e93

refactor: Llama3 weight init now always uses fp32 trunc normal and optionally casts the result down to the model dtype. This improves the stability of the weight init.
1 parent e08e56f · commit 3434e93

1 file changed: 40 additions & 5 deletions

src/modalities/models/gpt2/llama3_like_initialization.py
@@ -2,6 +2,7 @@
 import re
 from typing import Annotated, Callable
 
+import torch
 import torch.nn as nn
 from pydantic import BaseModel, Field
 
@@ -31,7 +32,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
             # lm head weights
             r"transformer\.lm_head\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": 1 / math.sqrt(n_embd),
@@ -41,7 +42,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             ),
             # qkv projections
             r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": 0.02,
@@ -51,7 +52,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             ),
             # final attention projection in attention block
             r"transformer\.h\.\d+\.attn\.c_proj\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": (
@@ -65,7 +66,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             ),
             # SwiGLU
             r"transformer\.h\.\d+\.mlp\.(W)\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": 0.02,
@@ -74,7 +75,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
                 },
             ),
             r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": (
@@ -132,3 +133,37 @@ def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable
             raise ValueError(
                 f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3."
             )
+
+
+def trunc_normal_(
+    tensor: torch.Tensor,
+    mean: float = 0.0,
+    std: float = 1.0,
+    a: float = -2.0,
+    b: float = 2.0,
+):
+    """
+    Fills the input tensor with values sampled from a truncated normal distribution.
+    Values are drawn from a normal distribution with the given mean and standard
+    deviation. Any sampled values outside the range defined by a and b are resampled
+    until they fall within the bounds.
+
+    To avoid numerical instability in torch.nn.init.trunc_normal_, the initialization
+    is always performed using float32 precision. The result is then cast back to the
+    original data type of the input tensor.
+
+    Args:
+        tensor: an n dimensional torch Tensor
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the lower bound for truncation
+        b: the upper bound for truncation
+
+    Returns:
+        The input tensor filled with values from the truncated normal distribution.
+    """
+    # This function is copied from Meta's open-source project TorchTitan,
+    # licensed under the BSD 3-Clause License.
+    tmp = tensor.float()
+    nn.init.trunc_normal_(tmp, mean=mean, std=std, a=a, b=b)
+    tensor.copy_(tmp)
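For context (not part of the commit): a minimal usage sketch of the new wrapper, assuming a hypothetical bfloat16 weight tensor. It shows that sampling runs in float32 inside trunc_normal_ while the parameter keeps its original dtype and stays within the truncation bounds.

import torch

# Import path inferred from the file location above (src layout assumed).
from modalities.models.gpt2.llama3_like_initialization import trunc_normal_

# Hypothetical low-precision weight; name and shape are illustrative only.
weight = torch.empty(1024, 4096, dtype=torch.bfloat16)

# Sampling happens in float32 internally; copy_ casts the result back to bfloat16.
trunc_normal_(weight, mean=0.0, std=0.02, a=-0.06, b=0.06)

assert weight.dtype == torch.bfloat16        # model dtype is preserved
assert weight.float().abs().max() <= 0.061   # values respect the truncation bounds (up to bf16 rounding)

Performing the init on a float32 copy and writing it back via copy_ is what makes the downcast to the model dtype optional: if the weight is already float32, the copy changes nothing.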
