44
55import torch
66import torch .nn as nn
7+ from einops import repeat
78
89from comfy .ldm .modules .attention import optimized_attention
9-
10+ from comfy .ldm .flux .layers import EmbedND
11+ from comfy .ldm .flux .math import apply_rope
1012
1113def sinusoidal_embedding_1d (dim , position ):
1214 # preprocess
@@ -21,45 +23,6 @@ def sinusoidal_embedding_1d(dim, position):
2123 return x
2224
2325
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex rotary-embedding factors for one axis.

    Returns a ``(max_seq_len, dim // 2)`` tensor of unit-magnitude
    complex128 values ``exp(i * pos * inv_freq)``, one rotation factor
    per (position, frequency-pair). ``dim`` must be even because each
    complex factor rotates a pair of real channels.
    """
    assert dim % 2 == 0
    positions = torch.arange(max_seq_len)
    # standard RoPE inverse-frequency schedule: theta^(-2k/dim), k = 0..dim/2-1
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
    angles = torch.outer(positions, inv_freq)
    # unit-modulus complex numbers carrying only the rotation angle
    return torch.polar(torch.ones_like(angles), angles)
32-
33-
def rope_apply(x, grid_sizes, freqs):
    """Apply 3D (time/height/width) rotary position embedding to ``x``.

    ``x`` is ``(batch, seq, heads, head_dim)``; ``grid_sizes`` holds one
    ``(f, h, w)`` triple per sample, and ``freqs`` is a complex table of
    per-position rotation factors (as built by ``rope_params``) whose
    last dim equals ``head_dim // 2``. Tokens beyond ``f*h*w`` in each
    sample are passed through unrotated. Returns a tensor with the same
    shape and dtype as ``x``.
    """
    num_heads, half = x.size(2), x.size(3) // 2

    # split the frequency table per axis: time gets the remainder,
    # the two spatial axes get half // 3 pairs each
    f_time, f_height, f_width = freqs.split(
        [half - 2 * (half // 3), half // 3, half // 3], dim=1)

    rotated = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # pair up real channels into complex numbers (float64 for precision)
        x_complex = torch.view_as_complex(
            x[i, :seq_len].to(torch.float64).reshape(seq_len, num_heads, -1, 2))

        # broadcast the three axis tables over the full (f, h, w) grid,
        # then flatten to one rotation factor per token position
        rot = torch.cat([
            f_time[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            f_height[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            f_width[:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(seq_len, 1, -1)

        # complex multiply == 2D rotation of each channel pair
        out_i = torch.view_as_real(x_complex * rot).flatten(2)
        # keep any padding tokens past the grid untouched
        rotated.append(torch.cat([out_i, x[i, seq_len:]]))
    return torch.stack(rotated).to(dtype=x.dtype)
61-
62-
6326class WanRMSNorm (nn .Module ):
6427
6528 def __init__ (self , dim , eps = 1e-5 , device = None , dtype = None ):
@@ -122,10 +85,11 @@ def qkv_fn(x):
12285 return q , k , v
12386
12487 q , k , v = qkv_fn (x )
88+ q , k = apply_rope (q , k , freqs )
12589
12690 x = optimized_attention (
127- q = rope_apply ( q , grid_sizes , freqs ) .view (b , s , n * d ),
128- k = rope_apply ( k , grid_sizes , freqs ) .view (b , s , n * d ),
91+ q = q .view (b , s , n * d ),
92+ k = k .view (b , s , n * d ),
12993 v = v ,
13094 heads = self .num_heads ,
13195 )
@@ -433,14 +397,8 @@ def __init__(self,
433397 # head
434398 self .head = Head (dim , out_dim , patch_size , eps , operation_settings = operation_settings )
435399
436- # buffers (don't use register_buffer otherwise dtype will be changed in to())
437- assert (dim % num_heads ) == 0 and (dim // num_heads ) % 2 == 0
438400 d = dim // num_heads
439- self .register_buffer ("freqs" , torch .cat ([
440- rope_params (1024 , d - 4 * (d // 6 )),
441- rope_params (1024 , 2 * (d // 6 )),
442- rope_params (1024 , 2 * (d // 6 ))
443- ], dim = 1 ), persistent = False )
401+ self .rope_embedder = EmbedND (dim = d , theta = 10000.0 , axes_dim = [d - 4 * (d // 6 ), 2 * (d // 6 ), 2 * (d // 6 )])
444402
445403 if model_type == 'i2v' :
446404 self .img_emb = MLPProj (1280 , dim , operation_settings = operation_settings )
@@ -453,6 +411,7 @@ def forward_orig(
453411 seq_len = None ,
454412 clip_fea = None ,
455413 y = None ,
414+ freqs = None ,
456415 ):
457416 r"""
458417 Forward pass through the diffusion model
@@ -477,10 +436,6 @@ def forward_orig(
477436 """
478437 if self .model_type == 'i2v' :
479438 assert clip_fea is not None and y is not None
480- # params
481- # device = self.patch_embedding.weight.device
482- # if self.freqs.device != device:
483- # self.freqs = self.freqs.to(device)
484439
485440 if y is not None :
486441 x = [torch .cat ([u , v ], dim = 0 ) for u , v in zip (x , y )]
@@ -523,7 +478,7 @@ def forward_orig(
523478 e = e0 ,
524479 seq_lens = seq_lens ,
525480 grid_sizes = grid_sizes ,
526- freqs = self . freqs ,
481+ freqs = freqs ,
527482 context = context ,
528483 context_lens = context_lens )
529484
@@ -538,8 +493,20 @@ def forward_orig(
538493 return x
539494 # return [u.float() for u in x]
540495
def forward(self, x, timestep, context, y=None, image=None, **kwargs):
    """Build per-token 3D position ids, turn them into RoPE frequencies
    via ``self.rope_embedder``, and delegate to ``forward_orig``.

    ``x`` is a ``(batch, channels, t, h, w)`` latent video tensor;
    ``timestep`` and ``context`` are forwarded unchanged, ``y`` becomes
    ``clip_fea`` and ``image`` becomes the conditioning ``y`` of
    ``forward_orig``. Returns the single-sample output of
    ``forward_orig``.
    """
    bs, _, t, h, w = x.shape
    pt, ph, pw = self.patch_size
    # token count per axis, rounding up when the extent is not an exact
    # multiple of the patch size
    t_len = (t + pt // 2) // pt
    h_len = (h + ph // 2) // ph
    w_len = (w + pw // 2) // pw

    # integer (time, height, width) coordinate for every patch token
    ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
    ids[:, :, :, 0] += torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
    ids[:, :, :, 1] += torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
    ids[:, :, :, 2] += torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
    # flatten the grid to one id per token and copy across the batch
    ids = ids.reshape(1, t_len * h_len * w_len, 3).repeat(bs, 1, 1)

    freqs = self.rope_embedder(ids).movedim(1, 2)
    return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
543510
544511 def unpatchify (self , x , grid_sizes ):
545512 r"""
0 commit comments