Skip to content

Commit ab13000

Browse files
Do RMSNorm in native type.
1 parent ca4b8f3 commit ab13000

1 file changed

Lines changed: 1 addition & 3 deletions

File tree

comfy/ldm/flux/layers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,8 @@ def __init__(self, dim: int, dtype=None, device=None, operations=None):
6363
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
6464

6565
def forward(self, x: Tensor):
66-
x_dtype = x.dtype
67-
x = x.float()
6866
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
69-
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
67+
return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
7068

7169

7270
class QKNorm(torch.nn.Module):

0 commit comments

Comments
 (0)