@@ -407,3 +407,31 @@ class Cosmos1CV8x8x8(LatentFormat):
407407 ]
408408
409409 latent_rgb_factors_bias = [- 0.1223 , - 0.1889 , - 0.1976 ]
410+
class Wan21(LatentFormat):
    """Latent format with 16 channels and per-channel mean/std normalization.

    The per-channel statistics are stored with shape
    ``(1, latent_channels, 1, 1, 1)`` so they broadcast against a 5-D latent
    whose channel axis is dimension 1 (presumably
    ``(batch, channels, frames, height, width)`` -- TODO confirm with callers).
    """

    latent_channels = 16
    latent_dimensions = 3

    def __init__(self):
        # One mean / std entry per latent channel, in channel order.
        channel_means = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
        ]
        channel_stds = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
        ]
        # Singleton axes everywhere except the channel axis, for broadcasting.
        broadcast_shape = (1, self.latent_channels, 1, 1, 1)

        self.scale_factor = 1.0
        self.latents_mean = torch.tensor(channel_means).view(*broadcast_shape)
        self.latents_std = torch.tensor(channel_stds).view(*broadcast_shape)

        self.taesd_decoder_name = None  # TODO

    def process_in(self, latent):
        """Normalize ``latent``: subtract the mean, scale, divide by the std.

        The stored statistics are moved to ``latent``'s device and dtype
        before use. Inverse of :meth:`process_out`.
        """
        mean = self.latents_mean.to(device=latent.device, dtype=latent.dtype)
        std = self.latents_std.to(device=latent.device, dtype=latent.dtype)
        centered = latent - mean
        return centered * self.scale_factor / std

    def process_out(self, latent):
        """Undo :meth:`process_in`: multiply by the std, unscale, add the mean.

        The stored statistics are moved to ``latent``'s device and dtype
        before use.
        """
        mean = self.latents_mean.to(device=latent.device, dtype=latent.dtype)
        std = self.latents_std.to(device=latent.device, dtype=latent.dtype)
        rescaled = latent * std / self.scale_factor
        return rescaled + mean
0 commit comments