@@ -10,6 +10,7 @@
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.ldm.common_dit
+import comfy.model_management

 def sinusoidal_embedding_1d(dim, position):
     # preprocess
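This hunk cuts off before the body of `sinusoidal_embedding_1d`. For orientation, a minimal sketch of a conventional 1D sinusoidal embedding; the cos/sin ordering, float64 accumulation, and 10000 base are assumptions about the surrounding file, not part of this diff:

```python
import torch

def sinusoidal_embedding_1d_sketch(dim, position):
    # Half the channels get cos, half sin, with frequencies spaced
    # geometrically from 1 down to 1/10000 (assumed layout).
    assert dim % 2 == 0
    half = dim // 2
    freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float64) / half)
    sinusoid = torch.outer(position.to(torch.float64), freqs)
    return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
```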
@@ -37,7 +38,7 @@ def forward(self, x):
         Args:
             x(Tensor): Shape [B, L, C]
         """
-        return self._norm(x.float()).type_as(x) * self.weight
+        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
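The recurring change in this diff is reading stored parameters through `comfy.model_management.cast_to`, which returns the tensor converted to the requested dtype/device so quantized or offloaded weights can be combined with activations living elsewhere. A rough, illustrative model of the behavior (not ComfyUI's actual implementation, which also handles things like non-blocking transfers):

```python
import torch

def cast_to_sketch(weight, dtype=None, device=None):
    # No-op when the parameter already matches the activations;
    # otherwise convert/move it for this forward pass only.
    if (dtype is None or weight.dtype == dtype) and \
       (device is None or weight.device == device):
        return weight
    return weight.to(device=device, dtype=dtype)
```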
@@ -125,7 +126,7 @@ def __init__(self,
                  window_size=(-1, -1),
                  qk_norm=True,
                  eps=1e-6, operation_settings={}):
-        super().__init__(dim, num_heads, window_size, qk_norm, eps)
+        super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)

         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
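`operation_settings` bundles the layer factory together with the target device/dtype, so the `k_img`/`v_img` projections are built through the same op classes as the rest of the model, which is why the base-class `__init__` now needs it too. A hypothetical minimal value, using plain `torch.nn` as the factory purely for illustration (ComfyUI normally passes its own ops classes here):

```python
import torch
import torch.nn as nn

# Anything exposing .Linear (and friends) works as "operations".
operation_settings = {
    "operations": nn,
    "device": torch.device("cpu"),
    "dtype": torch.float32,
}

k_img = operation_settings.get("operations").Linear(
    1024, 1024,
    device=operation_settings.get("device"),
    dtype=operation_settings.get("dtype"))
```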
@@ -218,7 +219,7 @@ def forward(
         """
         # assert e.dtype == torch.float32

-        e = (self.modulation + e).chunk(6, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         # assert e[0].dtype == torch.float32

         # self-attention
@@ -263,7 +264,7 @@ def forward(self, x, e):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
-        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x

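Both modulation sites follow the same AdaLN-style pattern: a learned `[1, N, C]` table is added to the conditioning embedding `e` and split into shift/scale (and, in the blocks, gate) terms, and the `cast_to` wrapper only changes how the table is read, not the math. A self-contained sketch of the six-way block case; variable names are descriptive assumptions:

```python
import torch
import torch.nn.functional as F

B, L, C = 2, 16, 128
x = torch.randn(B, L, C)
e = torch.randn(B, 6, C)            # conditioning projected to six C-sized slots
modulation = torch.randn(1, 6, C)   # learned table; may live in another dtype/device

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
    (modulation.to(dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)

attn_in = F.layer_norm(x, (C,)) * (1 + scale_msa) + shift_msa
# attention(attn_in) is then scaled by gate_msa; the MLP branch reuses the
# *_mlp triplet, and the output head uses the analogous two-way chunk above.
```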
@@ -401,7 +402,6 @@ def forward_orig(
         t,
         context,
         clip_fea=None,
-        y=None,
         freqs=None,
     ):
         r"""
@@ -425,12 +425,6 @@
             List[Tensor]:
                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         """
-        if self.model_type == 'i2v':
-            assert clip_fea is not None and y is not None
-
-        if y is not None:
-            x = torch.cat([x, y], dim=0)
-
         # embeddings
         x = self.patch_embedding(x)
         grid_sizes = x.shape[2:]
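With the `y` concat removed, `forward_orig` now expects fully assembled latents, and the first thing it does is the strided-conv patchify whose grid sizes are kept for `unpatchify`. A sketch of the shape bookkeeping with illustrative sizes (16 latent channels, 1536 model dim, and (1, 2, 2) patches are assumptions, not taken from this diff):

```python
import torch
import torch.nn as nn

patch_size = (1, 2, 2)                                  # (t, h, w) patch dims
patch_embedding = nn.Conv3d(16, 1536, kernel_size=patch_size, stride=patch_size)

x = torch.randn(1, 16, 21, 60, 104)                     # [B, C, T, H, W] latents
x = patch_embedding(x)                                  # [B, 1536, 21, 30, 52]
grid_sizes = x.shape[2:]                                # saved for unpatchify
x = x.flatten(2).transpose(1, 2)                        # [B, 21*30*52, 1536] tokens
```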
@@ -465,7 +459,7 @@
         return x
         # return [u.float() for u in x]

-    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
@@ -479,7 +473,7 @@ def forward(self, x, timestep, context, y=None, image=None, **kwargs):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""