1414
1515
1616class ControlNetFlux (Flux ):
17- def __init__ (self , latent_input = False , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
17+ def __init__ (self , latent_input = False , num_union_modes = 0 , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
1818 super ().__init__ (final_layer = False , dtype = dtype , device = device , operations = operations , ** kwargs )
1919
2020 self .main_model_double = 19
@@ -29,6 +29,11 @@ def __init__(self, latent_input=False, image_model=None, dtype=None, device=None
2929 for _ in range (self .params .depth_single_blocks ):
3030 self .controlnet_single_blocks .append (operations .Linear (self .hidden_size , self .hidden_size , dtype = dtype , device = device ))
3131
32+ self .num_union_modes = num_union_modes
33+ self .controlnet_mode_embedder = None
34+ if self .num_union_modes > 0 :
35+ self .controlnet_mode_embedder = operations .Embedding (self .num_union_modes , self .hidden_size , dtype = dtype , device = device )
36+
3237 self .gradient_checkpointing = False
3338 self .latent_input = latent_input
3439 self .pos_embed_input = operations .Linear (self .in_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
@@ -61,6 +66,7 @@ def forward_orig(
6166 timesteps : Tensor ,
6267 y : Tensor ,
6368 guidance : Tensor = None ,
69+ control_type : Tensor = None ,
6470 ) -> Tensor :
6571 if img .ndim != 3 or txt .ndim != 3 :
6672 raise ValueError ("Input img and txt tensors must have 3 dimensions." )
@@ -79,6 +85,11 @@ def forward_orig(
7985 vec = vec + self .vector_in (y )
8086 txt = self .txt_in (txt )
8187
88+ if self .controlnet_mode_embedder is not None and len (control_type ) > 0 :
89+ control_cond = self .controlnet_mode_embedder (torch .tensor (control_type , device = img .device ), out_dtype = img .dtype ).unsqueeze (0 ).repeat ((txt .shape [0 ], 1 , 1 ))
90+ txt = torch .cat ([control_cond , txt ], dim = 1 )
91+ txt_ids = torch .cat ([txt_ids [:,:1 ], txt_ids ], dim = 1 )
92+
8293 ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
8394 pe = self .pe_embedder (ids )
8495
@@ -137,4 +148,4 @@ def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
137148 img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = bs )
138149
139150 txt_ids = torch .zeros ((bs , context .shape [1 ], 3 ), device = x .device , dtype = x .dtype )
140- return self .forward_orig (img , img_ids , hint , context , txt_ids , timesteps , y , guidance )
151+ return self .forward_orig (img , img_ids , hint , context , txt_ids , timesteps , y , guidance , control_type = kwargs . get ( "control_type" , []) )
0 commit comments