Skip to content

Commit d31e226

Browse files
Unify RMSNorm code.
1 parent b79fd7d commit d31e226

3 files changed

Lines changed: 17 additions & 24 deletions

File tree

comfy/ldm/common_dit.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
11
import torch
2+
import comfy.ops
23

34
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    """Pad ``img`` on the bottom/right so its last two dims are multiples of ``patch_size``.

    Args:
        img: Tensor whose last two dimensions are spatial (H, W).
        patch_size: (patch_h, patch_w) the spatial dims must be divisible by.
        padding_mode: Mode passed to ``torch.nn.functional.pad``.

    Returns:
        ``img`` padded to shape (..., ceil(H/ph)*ph, ceil(W/pw)*pw).
    """
    # Circular padding is not supported under torch.jit tracing/scripting, so
    # fall back to "reflect" in that case only. NOTE: the original condition
    # parsed as `(circular and tracing) or scripting` due to `and`/`or`
    # precedence, which silently overrode *any* padding mode while scripting;
    # the parentheses below restore the clear intent (circular-only fallback).
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"
    # Amount needed to reach the next multiple; the outer `% patch_size` makes
    # an already-aligned dimension pad by 0 instead of a full patch.
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    # F.pad's tuple is ordered (left, right, top, bottom) for the last 2 dims;
    # pad only on the right/bottom so existing pixel coordinates are preserved.
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
10+
11+
try:
    # Fused RMSNorm is only available in newer PyTorch builds; probe for it once
    # at import time. Only the attribute lookup can fail here, so catch exactly
    # that (a bare `except:` would also swallow KeyboardInterrupt/SystemExit).
    rms_norm_torch = torch.nn.functional.rms_norm
except AttributeError:
    rms_norm_torch = None


def rms_norm(x, weight=None, eps=1e-6):
    """RMS-normalize ``x`` over its last dimension, optionally scaled by ``weight``.

    Uses the fused ``torch.nn.functional.rms_norm`` when this PyTorch build
    provides it, otherwise a manual rsqrt(mean(x^2)) implementation.

    Args:
        x: Input tensor; normalization is over the last dimension.
        weight: Optional 1-D scale parameter. May be ``None`` (e.g. mmdit's
            RMSNorm registers ``weight`` as None when ``learnable_scale`` is
            False), in which case only the normalization is applied. The
            original implementation crashed on ``weight.shape`` / ``x * None``
            in that case.
        eps: Small constant added to the mean of squares for stability.
    """
    if rms_norm_torch is not None:
        if weight is None:
            # Normalized shape is just the last dim when there is no scale.
            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
        # Cast the weight to the input's dtype/device so mixed-precision and
        # offloaded weights work transparently.
        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
    else:
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        if weight is None:
            return x * rrms
        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)

comfy/ldm/flux/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .math import attention, rope
88
import comfy.ops
9+
import comfy.ldm.common_dit
910

1011

1112
class EmbedND(nn.Module):
@@ -63,8 +64,7 @@ def __init__(self, dim: int, dtype=None, device=None, operations=None):
6364
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
6465

6566
def forward(self, x: Tensor):
    # Delegate to the shared helper in comfy.ldm.common_dit so every RMSNorm
    # variant in the codebase uses one implementation (the fused
    # torch.nn.functional.rms_norm when the installed PyTorch provides it).
    # self.scale is this layer's learnable per-channel weight; eps matches the
    # 1e-6 the previous inline implementation used.
    return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
6868

6969

7070
class QKNorm(torch.nn.Module):

comfy/ldm/modules/diffusionmodules/mmdit.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -355,29 +355,9 @@ def __init__(
355355
else:
356356
self.register_parameter("weight", None)
357357

358-
def _norm(self, x):
359-
"""
360-
Apply the RMSNorm normalization to the input tensor.
361-
Args:
362-
x (torch.Tensor): The input tensor.
363-
Returns:
364-
torch.Tensor: The normalized tensor.
365-
"""
366-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
367-
368358
def forward(self, x):
    # Delegate to the shared RMSNorm helper instead of the former local
    # _norm() + learnable_scale branching.
    # NOTE(review): when learnable_scale is False this class registers
    # self.weight as None (see register_parameter above) — confirm that
    # comfy.ldm.common_dit.rms_norm accepts a None weight, otherwise this
    # path regresses the old unscaled branch.
    return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
381361

382362

383363
class SwiGLUFeedForward(nn.Module):

0 commit comments

Comments (0)