diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 52f607e10..43d8cd1de 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -527,8 +527,10 @@ def forward(self, clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) - + x = self.patchify(x) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index c1e4dfb3a..b96d20ee0 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -1658,11 +1658,15 @@ def model_fn_wans2v( # x and s2v_pose_latents s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents - x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + x = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel # reference image - ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + ref_latents = dit.patchify(dit.patch_embedding(origin_ref_latents)) + rf, rh, rw = ref_latents.shape[-3], ref_latents.shape[-2], ref_latents.shape[-1] + ref_latents = ref_latents.flatten(2).transpose(1, 2) grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) x = torch.cat([x, ref_latents], dim=1) # mask diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index abf0f3fef..40b18821a 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -81,8 +81,10 @@ def usp_dit_forward(self, clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) - + x = self.patchify(x) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),