 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
+import comfy.ldm.common_dit
 
 def sinusoidal_embedding_1d(dim, position):
     # preprocess
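Note on sinusoidal_embedding_1d: this is the standard 1-D transformer timestep embedding, with log-spaced frequencies and concatenated cos/sin halves. A minimal sketch of the computation (an illustration of the technique, not the exact body of this function):

    import torch

    def sinusoidal_embedding_1d_sketch(dim, position):
        # position: [N] tensor of (possibly fractional) timesteps
        half = dim // 2
        freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float64) / half)
        args = torch.outer(position.to(torch.float64), freqs)        # [N, half]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=1)  # [N, dim]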
@@ -67,12 +68,10 @@ def __init__(self,
         self.norm_q = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
         self.norm_k = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, seq_lens, grid_sizes, freqs):
+    def forward(self, x, freqs):
         r"""
         Args:
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
-            seq_lens(Tensor): Shape [B]
-            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
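Dropping seq_lens and grid_sizes works because every sample in the batch is now padded to the same token count before it reaches attention, so a single [B, L, ...] tensor plus a shared RoPE table is sufficient. The rest of this forward presumably reduces to something like the following sketch (using the apply_rope/optimized_attention imports above; the exact projection names are assumptions):

    # q, k: [B, L, num_heads, head_dim]; v: [B, L, num_heads * head_dim]
    q, k = apply_rope(q, k, freqs)               # rotary position encoding on q/k
    x = optimized_attention(q.view(b, s, -1), k.view(b, s, -1), v, heads=n)
    return self.o(x)                             # project back to [B, L, C]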
@@ -100,12 +99,11 @@ def qkv_fn(x):
 
 class WanT2VCrossAttention(WanSelfAttention):
 
-    def forward(self, x, context, context_lens):
+    def forward(self, x, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
-            context_lens(Tensor): Shape [B]
         """
         # compute query, key, value
         q = self.norm_q(self.q(x))
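Since the text context is zero-padded to a fixed self.text_len in forward_orig (see below), a per-sample context_lens tensor no longer carries information, and plain cross-attention suffices. The remainder of this forward is roughly the following sketch (projection names as defined on WanSelfAttention):

    k = self.norm_k(self.k(context))  # keys/values from the fixed-length text tokens
    v = self.v(context)
    x = optimized_attention(q, k, v, heads=self.num_heads)
    return self.o(x)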
@@ -134,12 +132,11 @@ def __init__(self,
         # self.alpha = nn.Parameter(torch.zeros((1, )))
         self.norm_k_img = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, context, context_lens):
+    def forward(self, x, context):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
             context(Tensor): Shape [B, L2, C]
-            context_lens(Tensor): Shape [B]
         """
         context_img = context[:, :257]
         context = context[:, 257:]
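The leading 257 tokens are the projected CLIP image features that forward_orig prepends to the text context (bs x 257 x dim: presumably one class token plus 16x16 patch tokens from a 224px CLIP ViT). The I2V variant attends to the image and text streams separately and sums the results; a sketch of the remainder, assuming the k_img/v_img projections this class defines:

    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(context))
    v = self.v(context)
    k_img = self.norm_k_img(self.k_img(context_img))
    v_img = self.v_img(context_img)
    img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
    x = optimized_attention(q, k, v, heads=self.num_heads)
    return self.o(x + img_x)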
@@ -210,18 +207,13 @@ def forward(
         self,
         x,
         e,
-        seq_lens,
-        grid_sizes,
         freqs,
         context,
-        context_lens,
     ):
         r"""
         Args:
             x(Tensor): Shape [B, L, C]
             e(Tensor): Shape [B, 6, C]
-            seq_lens(Tensor): Shape [B], length of each sequence in batch
-            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         # assert e.dtype == torch.float32
@@ -231,19 +223,19 @@ def forward(
 
         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
+            self.norm1(x) * (1 + e[1]) + e[0],
             freqs)
 
         x = x + y * e[2]
 
         # cross-attention & ffn function
-        def cross_attn_ffn(x, context, context_lens, e):
-            x = x + self.cross_attn(self.norm3(x), context, context_lens)
+        def cross_attn_ffn(x, context, e):
+            x = x + self.cross_attn(self.norm3(x), context)
             y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
             x = x + y * e[5]
             return x
 
-        x = cross_attn_ffn(x, context, context_lens, e)
+        x = cross_attn_ffn(x, context, e)
         return x
 
 
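For readers tracking the modulation math: e unflattens into six [B, C] vectors acting as adaLN-style (shift, scale, gate) triples, e[0:3] for the self-attention branch and e[3:6] for the FFN branch. Each branch follows the same pattern:

    h = norm(x) * (1 + scale) + shift  # condition normalized activations on the timestep
    x = x + branch(h) * gate           # gated residual update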
@@ -408,7 +400,6 @@ def forward_orig(
         x,
         t,
         context,
-        seq_len=None,
         clip_fea=None,
         y=None,
         freqs=None,
@@ -417,12 +408,12 @@ def forward_orig(
         Forward pass through the diffusion model
 
         Args:
-            x (List[Tensor]):
-                List of input video tensors, each with shape [C_in, F, H, W]
+            x (Tensor):
+                Batched input video tensor with shape [B, C_in, F, H, W]
             t (Tensor):
                 Diffusion timesteps tensor of shape [B]
             context (List[Tensor]):
-                List of text embeddings each with shape [L, C]
+                Text embeddings with shape [B, L, C]
             seq_len (`int`):
                 Maximum sequence length for positional encoding
             clip_fea (Tensor, *optional*):
@@ -438,36 +429,20 @@ def forward_orig(
             assert clip_fea is not None and y is not None
 
         if y is not None:
-            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+            x = torch.cat([x, y], dim=1)
 
         # embeddings
-        x = [self.patch_embedding(u) for u in x]
-        grid_sizes = torch.stack(
-            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
-        x = [u.flatten(2).transpose(1, 2) for u in x]
-        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
-        if seq_len is not None:
-            assert seq_lens.max() <= seq_len
-            x = torch.cat([
-                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
-                          dim=1) for u in x
-            ])
-        elif len(x) == 1:
-            x = x[0]
+        x = self.patch_embedding(x)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
 
         # time embeddings
         e = self.time_embedding(
             sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
         e0 = self.time_projection(e).unflatten(1, (6, self.dim))
 
         # context
-        context_lens = None
-        context = self.text_embedding(
-            torch.cat([
-                torch.cat(
-                    [u, u.new_zeros(u.size(0), self.text_len - u.size(1), u.size(2))], dim=1)
-                for u in context
-            ], dim=0))
+        context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
 
         if clip_fea is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
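Note that the image-conditioning concat above must use dim=1 in the batched form: with per-sample [C_in, F, H, W] tensors the channel axis was dim=0, but on a batched [B, C_in, F, H, W] tensor it is dim=1. The embedding step likewise assumes a uniformly shaped batch, so grid_sizes becomes a plain tuple shared by all samples. A worked shape example under the usual Wan patch_size of (1, 2, 2) (sizes here are hypothetical):

    # x: [1, 16, 21, 60, 104] latent -> Conv3d patch_embedding -> [1, dim, 21, 30, 52]
    f, h, w = 21, 30, 52     # grid_sizes after the conv
    L = f * h * w            # flatten(2).transpose(1, 2) yields [1, L, dim], L = 32760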
@@ -476,11 +451,8 @@ def forward_orig(
         # arguments
         kwargs = dict(
             e=e0,
-            seq_lens=seq_lens,
-            grid_sizes=grid_sizes,
             freqs=freqs,
-            context=context,
-            context_lens=context_lens)
+            context=context)
 
         for block in self.blocks:
             x = block(x, **kwargs)
@@ -495,6 +467,7 @@ def forward_orig(
 
     def forward(self, x, timestep, context, y=None, image=None, **kwargs):
         bs, c, t, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -506,7 +479,7 @@ def forward(self, x, timestep, context, y=None, image=None, **kwargs):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
 
         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
+        return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
 
     def unpatchify(self, x, grid_sizes):
         r"""
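The new pad/crop pair keeps odd-sized latents working: t_len/h_len/w_len are ceil-divisions by the patch size, pad_to_patch_size grows x to a whole number of patches, and the trailing [:, :, :t, :h, :w] slice removes the padding again after unpatchify. A small worked check with hypothetical sizes:

    p_t, p_h, p_w = 1, 2, 2        # Wan patch size
    t, h, w = 21, 61, 104          # latent sizes before padding
    t_len = (t + p_t // 2) // p_t  # 21: the time axis is effectively unpatched
    h_len = (h + p_h // 2) // p_h  # (61 + 1) // 2 = 31 == ceil(61 / 2)
    w_len = (w + p_w // 2) // p_w  # (104 + 1) // 2 = 52
    # pad_to_patch_size pads h from 61 to 62, and the final crop drops that row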
@@ -521,14 +494,13 @@ def unpatchify(self, x, grid_sizes):
 
         Returns:
             List[Tensor]:
-                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+                Reconstructed video tensor with shape [B, C_out, F, H / 8, W / 8]
         """
 
         c = self.out_dim
-        out = []
-        for u, v in zip(x, grid_sizes.tolist()):
-            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
-            u = torch.einsum('fhwpqrc->cfphqwr', u)
-            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
-            out.append(u)
-        return out
+        u = x
+        b = u.shape[0]
+        u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
+        u = torch.einsum('bfhwpqrc->bcfphqwr', u)
+        u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
+        return u
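A self-contained shape check for the batched unpatchify, with the same hypothetical grid as above (runs standalone):

    import math
    import torch

    b, c = 1, 16
    grid_sizes, patch_size = (21, 30, 52), (1, 2, 2)
    x = torch.randn(b, math.prod(grid_sizes), math.prod(patch_size) * c)
    u = x.view(b, *grid_sizes, *patch_size, c)  # split tokens into grid + patch axes
    u = torch.einsum('bfhwpqrc->bcfphqwr', u)   # interleave patch dims with the grid
    u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, patch_size)])
    print(u.shape)                              # torch.Size([1, 16, 21, 60, 104])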