Commit 4ced06b

WIP support for Wan I2V model.
1 parent cb06e96 commit 4ced06b

6 files changed

Lines changed: 116 additions & 17 deletions

comfy/ldm/wan/model.py

Lines changed: 7 additions & 13 deletions
@@ -10,6 +10,7 @@
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.ldm.common_dit
+import comfy.model_management

 def sinusoidal_embedding_1d(dim, position):
     # preprocess
@@ -37,7 +38,7 @@ def forward(self, x):
         Args:
             x(Tensor): Shape [B, L, C]
         """
-        return self._norm(x.float()).type_as(x) * self.weight
+        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
@@ -125,7 +126,7 @@ def __init__(self,
                  window_size=(-1, -1),
                  qk_norm=True,
                  eps=1e-6, operation_settings={}):
-        super().__init__(dim, num_heads, window_size, qk_norm, eps)
+        super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)

         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -218,7 +219,7 @@ def forward(
         """
         # assert e.dtype == torch.float32

-        e = (self.modulation + e).chunk(6, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         # assert e[0].dtype == torch.float32

         # self-attention
@@ -263,7 +264,7 @@ def forward(self, x, e):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
-        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x

@@ -401,7 +402,6 @@ def forward_orig(
         t,
         context,
         clip_fea=None,
-        y=None,
         freqs=None,
     ):
         r"""
@@ -425,12 +425,6 @@
             List[Tensor]:
                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         """
-        if self.model_type == 'i2v':
-            assert clip_fea is not None and y is not None
-
-        if y is not None:
-            x = torch.cat([x, y], dim=0)
-
         # embeddings
         x = self.patch_embedding(x)
         grid_sizes = x.shape[2:]
@@ -465,7 +459,7 @@
         return x
         # return [u.float() for u in x]

-    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
@@ -479,7 +473,7 @@ def forward(self, x, timestep, context, y=None, image=None, **kwargs):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""

comfy/model_base.py

Lines changed: 34 additions & 2 deletions
@@ -929,13 +929,45 @@ def extra_conds(self, **kwargs):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

-class WAN21_T2V(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+class WAN21(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+        self.image_to_video = image_to_video
+
+    def concat_cond(self, **kwargs):
+        if not self.image_to_video:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = self.process_latent_in(image)
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.zeros_like(noise)[:, :4]
+        else:
+            mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
+            mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            if mask.shape[-3] < noise.shape[-3]:
+                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+            mask = mask.view(mask.shape[0], -1, 4, mask.shape[-2], mask.shape[-1]).transpose(1, 2)
+        mask = utils.resize_to_batch_size(mask, noise.shape[0])
+        return torch.cat((mask, image), dim=1)

     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        clip_vision_output = kwargs.get("clip_vision_output", None)
+        if clip_vision_output is not None:
+            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
         return out
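
The mask handling in concat_cond is easiest to follow as a shape walkthrough: the Wan VAE packs 4 pixel frames into each latent frame, so a per-pixel-frame mask of shape [B, 1, T*4, H, W] is folded into 4 channels per latent frame before being concatenated with the 16-channel image latent (4 + 16 = 20 concat channels on top of the 16 noise channels). A small sketch with assumed sizes:

import torch

B, T_latent, H, W = 1, 31, 90, 160
mask = torch.ones(B, 1, T_latent * 4, H, W)  # one value per pixel frame
mask[:, :, :4] = 0.0                         # first group of 4 pixel frames is "given"

# Fold groups of 4 pixel frames into the channel dim, as concat_cond does.
folded = mask.view(mask.shape[0], -1, 4, mask.shape[-2], mask.shape[-1]).transpose(1, 2)
print(folded.shape)  # torch.Size([1, 4, 31, 90, 160]) -> aligned with latent frames

image_latent = torch.zeros(B, 16, T_latent, H, W)  # placeholder for the encoded start image
print(torch.cat((folded, image_latent), dim=1).shape)  # torch.Size([1, 20, 31, 90, 160])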

comfy/model_detection.py

Lines changed: 1 addition & 0 deletions
@@ -313,6 +313,7 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["qk_norm"] = True
         dit_config["cross_attn_norm"] = True
         dit_config["eps"] = 1e-6
+        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
         if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
             dit_config["model_type"] = "i2v"
         else:
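
The added in_dim detection reads the input channel count straight off the patch embedding weight: for a Conv3d, the weight shape is [out_channels, in_channels, kT, kH, kW], so shape[1] is the model's in_dim. This is what separates t2v checkpoints (16 input channels) from i2v ones (36: 16 noise + 4 mask + 16 image latent channels). An illustrative snippet, with assumed i2v-like dimensions:

import torch

patch_embedding = torch.nn.Conv3d(36, 5120, kernel_size=(1, 2, 2))  # assumed i2v-like dims
print(patch_embedding.weight.shape[1])  # 36 -> stored as dit_config["in_dim"]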

comfy/supported_models.py

Lines changed: 12 additions & 2 deletions
@@ -917,14 +917,24 @@ class WAN21_T2V(supported_models_base.BASE):
     text_encoder_key_prefix = ["text_encoders."]

     def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21_T2V(self, device=device)
+        out = model_base.WAN21(self, device=device)
         return out

     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V]
+class WAN21_I2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "i2v",
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21(self, image_to_video=True, device=device)
+        return out
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]

 models += [SVD_img2vid]
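
WAN21_I2V only overrides the detection config and flips image_to_video=True; detection fills model_type from the state dict (see the model_detection change above), and the supported-model class whose unet_config entries all agree with the detected config is the one selected. A rough sketch of that subset match, with a hypothetical matches() helper and assumed detected values:

detected = {"image_model": "wan2.1", "model_type": "i2v", "in_dim": 36}

def matches(unet_config, detected):
    # Hypothetical helper: a class matches if every key it pins agrees.
    return all(detected.get(k) == v for k, v in unet_config.items())

print(matches({"image_model": "wan2.1", "model_type": "t2v"}, detected))  # False
print(matches({"image_model": "wan2.1", "model_type": "i2v"}, detected))  # True -> WAN21_I2V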

comfy_extras/nodes_wan.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+import nodes
+import node_helpers
+import torch
+import comfy.model_management
+import comfy.utils
+
+
+def masked_images(num_images):
+    rem = 4 - (num_images % 4)
+    if rem == 4:
+        return num_images
+    return rem + num_images
+
+
+class WanImageToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "vae": ("VAE", ),
+                             "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 720, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             },
+                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "start_image": ("IMAGE", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            image[:start_image.shape[0]] = start_image
+
+            concat_latent_image = vae.encode(image[:, :, :, :3])
+            mask = torch.ones((1, 1, latent.shape[2] * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            mask[:, :, :masked_images(start_image.shape[0])] = 0.0
+
+            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
+
+NODE_CLASS_MAPPINGS = {
+    "WanImageToVideo": WanImageToVideo,
+}
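
The frame arithmetic in WanImageToVideo: the temporal compression maps the first pixel frame to one latent frame and each following group of 4 frames to one more, so latent_frames = ((length - 1) // 4) + 1; the concat mask spans latent_frames * 4 pixel frames, and masked_images rounds the number of provided frames up to a whole group of 4 so the "given" region covers complete latent frames. A quick check in plain Python:

def masked_images(num_images):
    rem = 4 - (num_images % 4)
    if rem == 4:
        return num_images
    return rem + num_images

length = 121
latent_frames = ((length - 1) // 4) + 1
print(latent_frames)      # 31 latent frames for 121 pixel frames
print(latent_frames * 4)  # 124 -> temporal size of the concat mask
print(masked_images(1))   # 4 -> a single start image marks a whole group of 4 as given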

nodes.py

Lines changed: 1 addition & 0 deletions
@@ -2269,6 +2269,7 @@ def init_builtin_extra_nodes():
         "nodes_cosmos.py",
         "nodes_video.py",
         "nodes_lumina2.py",
+        "nodes_wan.py",
     ]

     import_failed = []
