1414
1515
1616class ControlNetFlux (Flux ):
17- def __init__ (self , latent_input = False , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
17+ def __init__ (self , latent_input = False , num_union_modes = 0 , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
1818 super ().__init__ (final_layer = False , dtype = dtype , device = device , operations = operations , ** kwargs )
1919
2020 self .main_model_double = 19
@@ -23,8 +23,17 @@ def __init__(self, latent_input=False, image_model=None, dtype=None, device=None
2323 self .controlnet_blocks = nn .ModuleList ([])
2424 for _ in range (self .params .depth ):
2525 controlnet_block = operations .Linear (self .hidden_size , self .hidden_size , dtype = dtype , device = device )
26- # controlnet_block = zero_module(controlnet_block)
2726 self .controlnet_blocks .append (controlnet_block )
27+
28+ self .controlnet_single_blocks = nn .ModuleList ([])
29+ for _ in range (self .params .depth_single_blocks ):
30+ self .controlnet_single_blocks .append (operations .Linear (self .hidden_size , self .hidden_size , dtype = dtype , device = device ))
31+
32+ self .num_union_modes = num_union_modes
33+ self .controlnet_mode_embedder = None
34+ if self .num_union_modes > 0 :
35+ self .controlnet_mode_embedder = operations .Embedding (self .num_union_modes , self .hidden_size , dtype = dtype , device = device )
36+
2837 self .gradient_checkpointing = False
2938 self .latent_input = latent_input
3039 self .pos_embed_input = operations .Linear (self .in_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
@@ -57,6 +66,7 @@ def forward_orig(
5766 timesteps : Tensor ,
5867 y : Tensor ,
5968 guidance : Tensor = None ,
69+ control_type : Tensor = None ,
6070 ) -> Tensor :
6171 if img .ndim != 3 or txt .ndim != 3 :
6272 raise ValueError ("Input img and txt tensors must have 3 dimensions." )
@@ -75,29 +85,47 @@ def forward_orig(
7585 vec = vec + self .vector_in (y )
7686 txt = self .txt_in (txt )
7787
88+ if self .controlnet_mode_embedder is not None and len (control_type ) > 0 :
89+ control_cond = self .controlnet_mode_embedder (torch .tensor (control_type , device = img .device ), out_dtype = img .dtype ).unsqueeze (0 ).repeat ((txt .shape [0 ], 1 , 1 ))
90+ txt = torch .cat ([control_cond , txt ], dim = 1 )
91+ txt_ids = torch .cat ([txt_ids [:,:1 ], txt_ids ], dim = 1 )
92+
7893 ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
7994 pe = self .pe_embedder (ids )
8095
81- block_res_samples = ()
96+ controlnet_double = ()
97+
98+ for i in range (len (self .double_blocks )):
99+ img , txt = self .double_blocks [i ](img = img , txt = txt , vec = vec , pe = pe )
100+ controlnet_double = controlnet_double + (self .controlnet_blocks [i ](img ),)
82101
83- for block in self .double_blocks :
84- img , txt = block (img = img , txt = txt , vec = vec , pe = pe )
85- block_res_samples = block_res_samples + (img ,)
102+ img = torch .cat ((txt , img ), 1 )
86103
87- controlnet_block_res_samples = ()
88- for block_res_sample , controlnet_block in zip (block_res_samples , self .controlnet_blocks ):
89- block_res_sample = controlnet_block (block_res_sample )
90- controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample ,)
104+ controlnet_single = ()
91105
106+ for i in range (len (self .single_blocks )):
107+ img = self .single_blocks [i ](img , vec = vec , pe = pe )
108+ controlnet_single = controlnet_single + (self .controlnet_single_blocks [i ](img [:, txt .shape [1 ] :, ...]),)
92109
93- repeat = math .ceil (self .main_model_double / len (controlnet_block_res_samples ))
110+ repeat = math .ceil (self .main_model_double / len (controlnet_double ))
94111 if self .latent_input :
95112 out_input = ()
96- for x in controlnet_block_res_samples :
113+ for x in controlnet_double :
97114 out_input += (x ,) * repeat
98115 else :
99- out_input = (controlnet_block_res_samples * repeat )
100- return {"input" : out_input [:self .main_model_double ]}
116+ out_input = (controlnet_double * repeat )
117+
118+ out = {"input" : out_input [:self .main_model_double ]}
119+ if len (controlnet_single ) > 0 :
120+ repeat = math .ceil (self .main_model_single / len (controlnet_single ))
121+ out_output = ()
122+ if self .latent_input :
123+ for x in controlnet_single :
124+ out_output += (x ,) * repeat
125+ else :
126+ out_output = (controlnet_single * repeat )
127+ out ["output" ] = out_output [:self .main_model_single ]
128+ return out
101129
102130 def forward (self , x , timesteps , context , y , guidance = None , hint = None , ** kwargs ):
103131 patch_size = 2
@@ -120,4 +148,4 @@ def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
120148 img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = bs )
121149
122150 txt_ids = torch .zeros ((bs , context .shape [1 ], 3 ), device = x .device , dtype = x .dtype )
123- return self .forward_orig (img , img_ids , hint , context , txt_ids , timesteps , y , guidance )
151+ return self .forward_orig (img , img_ids , hint , context , txt_ids , timesteps , y , guidance , control_type = kwargs . get ( "control_type" , []) )
0 commit comments