@@ -2,6 +2,7 @@
 import re
 from typing import Annotated, Callable
 
+import torch
 import torch.nn as nn
 from pydantic import BaseModel, Field
 
@@ -31,7 +32,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
             # lm head weights
             r"transformer\.lm_head\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": 1 / math.sqrt(n_embd),
@@ -41,7 +42,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             ),
             # qkv projections
             r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": 0.02,
@@ -51,7 +52,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             ),
             # final attention projection in attention block
             r"transformer\.h\.\d+\.attn\.c_proj\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": (
@@ -65,7 +66,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
             ),
             # SwiGLU
             r"transformer\.h\.\d+\.mlp\.(W)\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": 0.02,
@@ -74,7 +75,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
                 },
             ),
             r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": (
-                nn.init.trunc_normal_,
+                trunc_normal_,
                 {
                     "mean": 0.0,
                     "std": (
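
The hunks above only swap the initializer in each `(init_fn, kwargs)` entry; the map itself is consumed elsewhere in the file by `_init_by_fqn_regex`, whose tail appears in the next hunk. As a rough sketch of how such a regex-keyed init map is typically applied (a hypothetical reconstruction, not the code from this commit; the helper name `apply_fqn_regex_init` is made up):

import re
from typing import Callable

import torch
import torch.nn as nn


def apply_fqn_regex_init(
    model: nn.Module,
    regex_to_init: dict[str, tuple[Callable, dict]],
) -> None:
    # Track which patterns matched so unmatched ones can be reported.
    matched: set[str] = set()
    with torch.no_grad():  # parameter init must not be tracked by autograd
        for fqn, param in model.named_parameters():
            for pattern, (init_fn, kwargs) in regex_to_init.items():
                if re.fullmatch(pattern, fqn):
                    init_fn(param, **kwargs)
                    matched.add(pattern)
    for pattern in regex_to_init:
        if pattern not in matched:
            raise ValueError(f"Regex {pattern} did not match any FQNs.")
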
@@ -132,3 +133,37 @@ def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable
         raise ValueError(
             f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3."
         )
+
+
+def trunc_normal_(
+    tensor: torch.Tensor,
+    mean: float = 0.0,
+    std: float = 1.0,
+    a: float = -2.0,
+    b: float = 2.0,
+):
+    """
+    Fills the input tensor with values sampled from a truncated normal distribution.
+    Values are drawn from a normal distribution with the given mean and standard
+    deviation. Any sampled values outside the range defined by a and b are resampled
+    until they fall within the bounds.
+
+    To avoid numerical instability in torch.nn.init.trunc_normal_, the initialization
+    is always performed in float32 precision. The result is then cast back to the
+    original data type of the input tensor.
+
+    Args:
+        tensor: an n-dimensional torch.Tensor
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the lower bound for truncation
+        b: the upper bound for truncation
+
+    Returns:
+        None. The input tensor is modified in place.
+    """
+    # This function is copied from Meta's open-source project TorchTitan,
+    # licensed under the BSD 3-Clause License.
+    tmp = tensor.float()
+    nn.init.trunc_normal_(tmp, mean=mean, std=std, a=a, b=b)
+    tensor.copy_(tmp)
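
The float32 round-trip is the point of the new wrapper: per its docstring, sampling with nn.init.trunc_normal_ directly on a low-precision tensor can be numerically unstable, so the wrapper samples into a float32 copy and writes the result back in the tensor's original dtype via tensor.copy_(tmp). A minimal usage sketch (the tensor shape and the 768 embedding width are illustrative assumptions, not values from this commit):

import math

import torch

n_embd = 768  # assumed embedding width, for illustration only
weight = torch.empty(1024, n_embd, dtype=torch.bfloat16)

# Sampling runs in float32 inside trunc_normal_; the final
# tensor.copy_(tmp) casts the result back to bfloat16.
trunc_normal_(weight, mean=0.0, std=1 / math.sqrt(n_embd))

print(weight.dtype)  # torch.bfloat16, values truncated to the default [-2.0, 2.0]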