Skip to content

Commit 6302301

Browse files
WIP support for Wan t2v model.
1 parent f400760 commit 6302301

10 files changed

Lines changed: 1307 additions & 3 deletions

File tree

comfy/latent_formats.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,31 @@ class Cosmos1CV8x8x8(LatentFormat):
407407
]
408408

409409
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
410+
411+
class Wan21(LatentFormat):
412+
latent_channels = 16
413+
latent_dimensions = 3
414+
415+
def __init__(self):
416+
self.scale_factor = 1.0
417+
self.latents_mean = torch.tensor([
418+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
419+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
420+
]).view(1, self.latent_channels, 1, 1, 1)
421+
self.latents_std = torch.tensor([
422+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
423+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
424+
]).view(1, self.latent_channels, 1, 1, 1)
425+
426+
427+
self.taesd_decoder_name = None #TODO
428+
429+
def process_in(self, latent):
430+
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
431+
latents_std = self.latents_std.to(latent.device, latent.dtype)
432+
return (latent - latents_mean) * self.scale_factor / latents_std
433+
434+
def process_out(self, latent):
435+
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
436+
latents_std = self.latents_std.to(latent.device, latent.dtype)
437+
return latent * latents_std / self.scale_factor + latents_mean

0 commit comments

Comments
 (0)