Skip to content

Commit f37551c

Browse files
Change the Wan RoPE implementation to the Flux one.
It should be more compatible.
1 parent 6302301 commit f37551c

1 file changed

Lines changed: 23 additions & 56 deletions

File tree

comfy/ldm/wan/model.py

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
import torch
66
import torch.nn as nn
7+
from einops import repeat
78

89
from comfy.ldm.modules.attention import optimized_attention
9-
10+
from comfy.ldm.flux.layers import EmbedND
11+
from comfy.ldm.flux.math import apply_rope
1012

1113
def sinusoidal_embedding_1d(dim, position):
1214
# preprocess
@@ -21,45 +23,6 @@ def sinusoidal_embedding_1d(dim, position):
2123
return x
2224

2325

24-
def rope_params(max_seq_len, dim, theta=10000):
25-
assert dim % 2 == 0
26-
freqs = torch.outer(
27-
torch.arange(max_seq_len),
28-
1.0 / torch.pow(theta,
29-
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
30-
freqs = torch.polar(torch.ones_like(freqs), freqs)
31-
return freqs
32-
33-
34-
def rope_apply(x, grid_sizes, freqs):
35-
n, c = x.size(2), x.size(3) // 2
36-
37-
# split freqs
38-
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
39-
40-
# loop over samples
41-
output = []
42-
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
43-
seq_len = f * h * w
44-
45-
# precompute multipliers
46-
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
47-
seq_len, n, -1, 2))
48-
freqs_i = torch.cat([
49-
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
50-
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
51-
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
52-
], dim=-1).reshape(seq_len, 1, -1)
53-
54-
# apply rotary embedding
55-
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
56-
x_i = torch.cat([x_i, x[i, seq_len:]])
57-
58-
# append to collection
59-
output.append(x_i)
60-
return torch.stack(output).to(dtype=x.dtype)
61-
62-
6326
class WanRMSNorm(nn.Module):
6427

6528
def __init__(self, dim, eps=1e-5, device=None, dtype=None):
@@ -122,10 +85,11 @@ def qkv_fn(x):
12285
return q, k, v
12386

12487
q, k, v = qkv_fn(x)
88+
q, k = apply_rope(q, k, freqs)
12589

12690
x = optimized_attention(
127-
q=rope_apply(q, grid_sizes, freqs).view(b, s, n * d),
128-
k=rope_apply(k, grid_sizes, freqs).view(b, s, n * d),
91+
q=q.view(b, s, n * d),
92+
k=k.view(b, s, n * d),
12993
v=v,
13094
heads=self.num_heads,
13195
)
@@ -433,14 +397,8 @@ def __init__(self,
433397
# head
434398
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
435399

436-
# buffers (don't use register_buffer otherwise dtype will be changed in to())
437-
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
438400
d = dim // num_heads
439-
self.register_buffer("freqs", torch.cat([
440-
rope_params(1024, d - 4 * (d // 6)),
441-
rope_params(1024, 2 * (d // 6)),
442-
rope_params(1024, 2 * (d // 6))
443-
], dim=1), persistent=False)
401+
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
444402

445403
if model_type == 'i2v':
446404
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
@@ -453,6 +411,7 @@ def forward_orig(
453411
seq_len=None,
454412
clip_fea=None,
455413
y=None,
414+
freqs=None,
456415
):
457416
r"""
458417
Forward pass through the diffusion model
@@ -477,10 +436,6 @@ def forward_orig(
477436
"""
478437
if self.model_type == 'i2v':
479438
assert clip_fea is not None and y is not None
480-
# params
481-
# device = self.patch_embedding.weight.device
482-
# if self.freqs.device != device:
483-
# self.freqs = self.freqs.to(device)
484439

485440
if y is not None:
486441
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@@ -523,7 +478,7 @@ def forward_orig(
523478
e=e0,
524479
seq_lens=seq_lens,
525480
grid_sizes=grid_sizes,
526-
freqs=self.freqs,
481+
freqs=freqs,
527482
context=context,
528483
context_lens=context_lens)
529484

@@ -538,8 +493,20 @@ def forward_orig(
538493
return x
539494
# return [u.float() for u in x]
540495

541-
def forward(self, x, t, context, y=None, image=None, **kwargs):
542-
return self.forward_orig([x], t, [context], clip_fea=y, y=image)[0]
496+
def forward(self, x, timestep, context, y=None, image=None, **kwargs):
497+
bs, c, t, h, w = x.shape
498+
patch_size = self.patch_size
499+
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
500+
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
501+
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
502+
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
503+
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
504+
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
505+
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
506+
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
507+
508+
freqs = self.rope_embedder(img_ids).movedim(1, 2)
509+
return self.forward_orig([x], timestep, [context], clip_fea=y, y=image, freqs=freqs)[0]
543510

544511
def unpatchify(self, x, grid_sizes):
545512
r"""

0 commit comments

Comments (0)