Commit fa62287

More code reuse in wan.
Fix bug when changing the compute dtype on wan.
1 parent: 0844998

2 files changed

Lines changed: 6 additions & 23 deletions

comfy/ldm/wan/model.py

Lines changed: 5 additions & 22 deletions
@@ -9,9 +9,11 @@
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 import comfy.ldm.common_dit
 import comfy.model_management
 
+
 def sinusoidal_embedding_1d(dim, position):
     # preprocess
     assert dim % 2 == 0
@@ -25,25 +27,6 @@ def sinusoidal_embedding_1d(dim, position):
     return x
 
 
-class WanRMSNorm(nn.Module):
-
-    def __init__(self, dim, eps=1e-5, device=None, dtype=None):
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
-
-    def forward(self, x):
-        r"""
-        Args:
-            x(Tensor): Shape [B, L, C]
-        """
-        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
-
 class WanSelfAttention(nn.Module):
 
     def __init__(self,
@@ -66,8 +49,8 @@ def __init__(self,
         self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.norm_q = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
-        self.norm_k = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, freqs):
         r"""
@@ -131,7 +114,7 @@ def __init__(self,
         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         # self.alpha = nn.Parameter(torch.zeros((1, )))
-        self.norm_k_img = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, context):
         r"""

comfy/model_patcher.py

Lines changed: 1 addition & 1 deletion
@@ -639,7 +639,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                     mem_counter += module_mem
                     load_completely.append((module_mem, n, m, params))
 
-            if cast_weight:
+            if cast_weight and hasattr(m, "comfy_cast_weights"):
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
 
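The comfy/model_patcher.py hunk is the dtype fix from the commit message. The new qk-norm layers are constructed directly rather than through the patched operations object, so they presumably do not carry a comfy_cast_weights attribute, and the old unconditional assignment would fail on them once weight casting was enabled (for example when the compute dtype changes). A minimal sketch of the guarded behaviour, with the loop and variable names assumed rather than taken verbatim from load():

# Sketch only: enable on-the-fly weight casting just for modules that support it.
for n, m in model.named_modules():  # assumed iteration over the model's modules
    if cast_weight and hasattr(m, "comfy_cast_weights"):
        # Remember the previous setting so it can be restored later,
        # then turn on runtime casting of this module's weights.
        m.prev_comfy_cast_weights = m.comfy_cast_weights
        m.comfy_cast_weights = True
    # Modules without the attribute (e.g. the plain RMSNorm layers above)
    # are now skipped instead of raising an AttributeError.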