11#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
22
33import torch
4+ import math
45from torch import Tensor , nn
56from einops import rearrange , repeat
67
1314
1415
1516class ControlNetFlux (Flux ):
16- def __init__ (self , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
17+ def __init__ (self , latent_input = False , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
1718 super ().__init__ (final_layer = False , dtype = dtype , device = device , operations = operations , ** kwargs )
1819
20+ self .main_model_double = 19
21+ self .main_model_single = 38
1922 # add ControlNet blocks
2023 self .controlnet_blocks = nn .ModuleList ([])
2124 for _ in range (self .params .depth ):
2225 controlnet_block = operations .Linear (self .hidden_size , self .hidden_size , dtype = dtype , device = device )
2326 # controlnet_block = zero_module(controlnet_block)
2427 self .controlnet_blocks .append (controlnet_block )
25- self .pos_embed_input = operations .Linear (self .in_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
2628 self .gradient_checkpointing = False
27- self .input_hint_block = nn .Sequential (
28- operations .Conv2d (3 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
29- nn .SiLU (),
30- operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
31- nn .SiLU (),
32- operations .Conv2d (16 , 16 , 3 , padding = 1 , stride = 2 , dtype = dtype , device = device ),
33- nn .SiLU (),
34- operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
35- nn .SiLU (),
36- operations .Conv2d (16 , 16 , 3 , padding = 1 , stride = 2 , dtype = dtype , device = device ),
37- nn .SiLU (),
38- operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
39- nn .SiLU (),
40- operations .Conv2d (16 , 16 , 3 , padding = 1 , stride = 2 , dtype = dtype , device = device ),
41- nn .SiLU (),
42- operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device )
43- )
29+ self .latent_input = latent_input
30+ self .pos_embed_input = operations .Linear (self .in_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
31+ if not self .latent_input :
32+ self .input_hint_block = nn .Sequential (
33+ operations .Conv2d (3 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
34+ nn .SiLU (),
35+ operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
36+ nn .SiLU (),
37+ operations .Conv2d (16 , 16 , 3 , padding = 1 , stride = 2 , dtype = dtype , device = device ),
38+ nn .SiLU (),
39+ operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
40+ nn .SiLU (),
41+ operations .Conv2d (16 , 16 , 3 , padding = 1 , stride = 2 , dtype = dtype , device = device ),
42+ nn .SiLU (),
43+ operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device ),
44+ nn .SiLU (),
45+ operations .Conv2d (16 , 16 , 3 , padding = 1 , stride = 2 , dtype = dtype , device = device ),
46+ nn .SiLU (),
47+ operations .Conv2d (16 , 16 , 3 , padding = 1 , dtype = dtype , device = device )
48+ )
4449
4550 def forward_orig (
4651 self ,
@@ -58,8 +63,10 @@ def forward_orig(
5863
5964 # running on sequences img
6065 img = self .img_in (img )
61- controlnet_cond = self .input_hint_block (controlnet_cond )
62- controlnet_cond = rearrange (controlnet_cond , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = 2 , pw = 2 )
66+ if not self .latent_input :
67+ controlnet_cond = self .input_hint_block (controlnet_cond )
68+ controlnet_cond = rearrange (controlnet_cond , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = 2 , pw = 2 )
69+
6370 controlnet_cond = self .pos_embed_input (controlnet_cond )
6471 img = img + controlnet_cond
6572 vec = self .time_in (timestep_embedding (timesteps , 256 ))
@@ -82,13 +89,25 @@ def forward_orig(
8289 block_res_sample = controlnet_block (block_res_sample )
8390 controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample ,)
8491
85- return {"input" : (controlnet_block_res_samples * 10 )[:19 ]}
92+
93+ repeat = math .ceil (self .main_model_double / len (controlnet_block_res_samples ))
94+ if self .latent_input :
95+ out_input = ()
96+ for x in controlnet_block_res_samples :
97+ out_input += (x ,) * repeat
98+ else :
99+ out_input = (controlnet_block_res_samples * repeat )
100+ return {"input" : out_input [:self .main_model_double ]}
86101
87102 def forward (self , x , timesteps , context , y , guidance = None , hint = None , ** kwargs ):
88- hint = hint * 2.0 - 1.0
103+ patch_size = 2
104+ if self .latent_input :
105+ hint = comfy .ldm .common_dit .pad_to_patch_size (hint , (patch_size , patch_size ))
106+ hint = rearrange (hint , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = patch_size , pw = patch_size )
107+ else :
108+ hint = hint * 2.0 - 1.0
89109
90110 bs , c , h , w = x .shape
91- patch_size = 2
92111 x = comfy .ldm .common_dit .pad_to_patch_size (x , (patch_size , patch_size ))
93112
94113 img = rearrange (x , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = patch_size , pw = patch_size )
0 commit comments