99from comfy .ldm .modules .attention import optimized_attention
1010from comfy .ldm .flux .layers import EmbedND
1111from comfy .ldm .flux .math import apply_rope
12+ from comfy .ldm .modules .diffusionmodules .mmdit import RMSNorm
1213import comfy .ldm .common_dit
1314import comfy .model_management
1415
16+
1517def sinusoidal_embedding_1d (dim , position ):
1618 # preprocess
1719 assert dim % 2 == 0
@@ -25,25 +27,6 @@ def sinusoidal_embedding_1d(dim, position):
2527 return x
2628
2729
class WanRMSNorm(nn.Module):
    """Root-mean-square normalization over the last (channel) dimension.

    The statistics are computed in float32 for numerical stability, the
    result is cast back to the input dtype, and a learned per-channel
    gain is applied.
    """

    def __init__(self, dim, eps=1e-5, device=None, dtype=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learned gain, initialized to 1 so the module starts as a pure RMS norm.
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        normed = self._norm(x.float()).type_as(x)
        # Move/cast the gain to match the activation before scaling.
        gain = comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)
        return normed * gain

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), written with rsqrt to keep it one multiply.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)
46-
4730class WanSelfAttention (nn .Module ):
4831
4932 def __init__ (self ,
@@ -66,8 +49,8 @@ def __init__(self,
6649 self .k = operation_settings .get ("operations" ).Linear (dim , dim , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" ))
6750 self .v = operation_settings .get ("operations" ).Linear (dim , dim , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" ))
6851 self .o = operation_settings .get ("operations" ).Linear (dim , dim , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" ))
69- self .norm_q = WanRMSNorm (dim , eps = eps , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )) if qk_norm else nn .Identity ()
70- self .norm_k = WanRMSNorm (dim , eps = eps , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )) if qk_norm else nn .Identity ()
52+ self .norm_q = RMSNorm (dim , eps = eps , elementwise_affine = True , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )) if qk_norm else nn .Identity ()
53+ self .norm_k = RMSNorm (dim , eps = eps , elementwise_affine = True , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )) if qk_norm else nn .Identity ()
7154
7255 def forward (self , x , freqs ):
7356 r"""
@@ -131,7 +114,7 @@ def __init__(self,
131114 self .k_img = operation_settings .get ("operations" ).Linear (dim , dim , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" ))
132115 self .v_img = operation_settings .get ("operations" ).Linear (dim , dim , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" ))
133116 # self.alpha = nn.Parameter(torch.zeros((1, )))
134- self .norm_k_img = WanRMSNorm (dim , eps = eps , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )) if qk_norm else nn .Identity ()
117+ self .norm_k_img = RMSNorm (dim , eps = eps , elementwise_affine = True , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )) if qk_norm else nn .Identity ()
135118
136119 def forward (self , x , context ):
137120 r"""
0 commit comments