Skip to content

Commit 9a66bb9

Browse files
Make wan work with all latent resolutions.
Clean up some code.
1 parent ea0f939 commit 9a66bb9

1 file changed

Lines changed: 26 additions & 54 deletions

File tree

comfy/ldm/wan/model.py

Lines changed: 26 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from comfy.ldm.modules.attention import optimized_attention
1010
from comfy.ldm.flux.layers import EmbedND
1111
from comfy.ldm.flux.math import apply_rope
12+
import comfy.ldm.common_dit
1213

1314
def sinusoidal_embedding_1d(dim, position):
1415
# preprocess
@@ -67,12 +68,10 @@ def __init__(self,
6768
self.norm_q = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
6869
self.norm_k = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
6970

70-
def forward(self, x, seq_lens, grid_sizes, freqs):
71+
def forward(self, x, freqs):
7172
r"""
7273
Args:
7374
x(Tensor): Shape [B, L, num_heads, C / num_heads]
74-
seq_lens(Tensor): Shape [B]
75-
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
7675
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
7776
"""
7877
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
@@ -100,12 +99,11 @@ def qkv_fn(x):
10099

101100
class WanT2VCrossAttention(WanSelfAttention):
102101

103-
def forward(self, x, context, context_lens):
102+
def forward(self, x, context):
104103
r"""
105104
Args:
106105
x(Tensor): Shape [B, L1, C]
107106
context(Tensor): Shape [B, L2, C]
108-
context_lens(Tensor): Shape [B]
109107
"""
110108
# compute query, key, value
111109
q = self.norm_q(self.q(x))
@@ -134,12 +132,11 @@ def __init__(self,
134132
# self.alpha = nn.Parameter(torch.zeros((1, )))
135133
self.norm_k_img = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
136134

137-
def forward(self, x, context, context_lens):
135+
def forward(self, x, context):
138136
r"""
139137
Args:
140138
x(Tensor): Shape [B, L1, C]
141139
context(Tensor): Shape [B, L2, C]
142-
context_lens(Tensor): Shape [B]
143140
"""
144141
context_img = context[:, :257]
145142
context = context[:, 257:]
@@ -210,18 +207,13 @@ def forward(
210207
self,
211208
x,
212209
e,
213-
seq_lens,
214-
grid_sizes,
215210
freqs,
216211
context,
217-
context_lens,
218212
):
219213
r"""
220214
Args:
221215
x(Tensor): Shape [B, L, C]
222216
e(Tensor): Shape [B, 6, C]
223-
seq_lens(Tensor): Shape [B], length of each sequence in batch
224-
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
225217
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
226218
"""
227219
# assert e.dtype == torch.float32
@@ -231,19 +223,19 @@ def forward(
231223

232224
# self-attention
233225
y = self.self_attn(
234-
self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes,
226+
self.norm1(x) * (1 + e[1]) + e[0],
235227
freqs)
236228

237229
x = x + y * e[2]
238230

239231
# cross-attention & ffn function
240-
def cross_attn_ffn(x, context, context_lens, e):
241-
x = x + self.cross_attn(self.norm3(x), context, context_lens)
232+
def cross_attn_ffn(x, context, e):
233+
x = x + self.cross_attn(self.norm3(x), context)
242234
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
243235
x = x + y * e[5]
244236
return x
245237

246-
x = cross_attn_ffn(x, context, context_lens, e)
238+
x = cross_attn_ffn(x, context, e)
247239
return x
248240

249241

@@ -408,7 +400,6 @@ def forward_orig(
408400
x,
409401
t,
410402
context,
411-
seq_len=None,
412403
clip_fea=None,
413404
y=None,
414405
freqs=None,
@@ -417,12 +408,12 @@ def forward_orig(
417408
Forward pass through the diffusion model
418409
419410
Args:
420-
x (List[Tensor]):
421-
List of input video tensors, each with shape [C_in, F, H, W]
411+
x (Tensor):
412+
Input video tensor with shape [B, C_in, F, H, W]
422413
t (Tensor):
423414
Diffusion timesteps tensor of shape [B]
424415
context (List[Tensor]):
425-
List of text embeddings each with shape [L, C]
416+
Text embedding tensor with shape [B, L, C]
426417
seq_len (`int`):
427418
Maximum sequence length for positional encoding
428419
clip_fea (Tensor, *optional*):
@@ -438,36 +429,20 @@ def forward_orig(
438429
assert clip_fea is not None and y is not None
439430

440431
if y is not None:
441-
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
432+
x = torch.cat([x, y], dim=0)
442433

443434
# embeddings
444-
x = [self.patch_embedding(u) for u in x]
445-
grid_sizes = torch.stack(
446-
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
447-
x = [u.flatten(2).transpose(1, 2) for u in x]
448-
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
449-
if seq_len is not None:
450-
assert seq_lens.max() <= seq_len
451-
x = torch.cat([
452-
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
453-
dim=1) for u in x
454-
])
455-
elif len(x) == 1:
456-
x = x[0]
435+
x = self.patch_embedding(x)
436+
grid_sizes = x.shape[2:]
437+
x = x.flatten(2).transpose(1, 2)
457438

458439
# time embeddings
459440
e = self.time_embedding(
460441
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
461442
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
462443

463444
# context
464-
context_lens = None
465-
context = self.text_embedding(
466-
torch.cat([
467-
torch.cat(
468-
[u, u.new_zeros(u.size(0), self.text_len - u.size(1), u.size(2))], dim=1)
469-
for u in context
470-
], dim=0))
445+
context = self.text_embedding(torch.cat([context, context.new_zeros(context.size(0), self.text_len - context.size(1), context.size(2))], dim=1))
471446

472447
if clip_fea is not None:
473448
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
@@ -476,11 +451,8 @@ def forward_orig(
476451
# arguments
477452
kwargs = dict(
478453
e=e0,
479-
seq_lens=seq_lens,
480-
grid_sizes=grid_sizes,
481454
freqs=freqs,
482-
context=context,
483-
context_lens=context_lens)
455+
context=context)
484456

485457
for block in self.blocks:
486458
x = block(x, **kwargs)
@@ -495,6 +467,7 @@ def forward_orig(
495467

496468
def forward(self, x, timestep, context, y=None, image=None, **kwargs):
497469
bs, c, t, h, w = x.shape
470+
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
498471
patch_size = self.patch_size
499472
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
500473
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -506,7 +479,7 @@ def forward(self, x, timestep, context, y=None, image=None, **kwargs):
506479
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
507480

508481
freqs = self.rope_embedder(img_ids).movedim(1, 2)
509-
return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
482+
return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
510483

511484
def unpatchify(self, x, grid_sizes):
512485
r"""
@@ -521,14 +494,13 @@ def unpatchify(self, x, grid_sizes):
521494
522495
Returns:
523496
List[Tensor]:
524-
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
497+
Reconstructed video tensor with shape [B, C_out, F, H / 8, W / 8]
525498
"""
526499

527500
c = self.out_dim
528-
out = []
529-
for u, v in zip(x, grid_sizes.tolist()):
530-
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
531-
u = torch.einsum('fhwpqrc->cfphqwr', u)
532-
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
533-
out.append(u)
534-
return out
501+
u = x
502+
b = u.shape[0]
503+
u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
504+
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
505+
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
506+
return u

0 commit comments

Comments
 (0)