@@ -23,8 +23,12 @@ def __init__(self, latent_input=False, image_model=None, dtype=None, device=None
2323 self .controlnet_blocks = nn .ModuleList ([])
2424 for _ in range (self .params .depth ):
2525 controlnet_block = operations .Linear (self .hidden_size , self .hidden_size , dtype = dtype , device = device )
26- # controlnet_block = zero_module(controlnet_block)
2726 self .controlnet_blocks .append (controlnet_block )
27+
28+ self .controlnet_single_blocks = nn .ModuleList ([])
29+ for _ in range (self .params .depth_single_blocks ):
30+ self .controlnet_single_blocks .append (operations .Linear (self .hidden_size , self .hidden_size , dtype = dtype , device = device ))
31+
2832 self .gradient_checkpointing = False
2933 self .latent_input = latent_input
3034 self .pos_embed_input = operations .Linear (self .in_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
@@ -78,26 +82,39 @@ def forward_orig(
7882 ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
7983 pe = self .pe_embedder (ids )
8084
81- block_res_samples = ()
85+ controlnet_double = ()
86+
87+ for i in range (len (self .double_blocks )):
88+ img , txt = self .double_blocks [i ](img = img , txt = txt , vec = vec , pe = pe )
89+ controlnet_double = controlnet_double + (self .controlnet_blocks [i ](img ),)
8290
83- for block in self .double_blocks :
84- img , txt = block (img = img , txt = txt , vec = vec , pe = pe )
85- block_res_samples = block_res_samples + (img ,)
91+ img = torch .cat ((txt , img ), 1 )
8692
87- controlnet_block_res_samples = ()
88- for block_res_sample , controlnet_block in zip (block_res_samples , self .controlnet_blocks ):
89- block_res_sample = controlnet_block (block_res_sample )
90- controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample ,)
93+ controlnet_single = ()
9194
95+ for i in range (len (self .single_blocks )):
96+ img = self .single_blocks [i ](img , vec = vec , pe = pe )
97+ controlnet_single = controlnet_single + (self .controlnet_single_blocks [i ](img [:, txt .shape [1 ] :, ...]),)
9298
93- repeat = math .ceil (self .main_model_double / len (controlnet_block_res_samples ))
99+ repeat = math .ceil (self .main_model_double / len (controlnet_double ))
94100 if self .latent_input :
95101 out_input = ()
96- for x in controlnet_block_res_samples :
102+ for x in controlnet_double :
97103 out_input += (x ,) * repeat
98104 else :
99- out_input = (controlnet_block_res_samples * repeat )
100- return {"input" : out_input [:self .main_model_double ]}
105+ out_input = (controlnet_double * repeat )
106+
107+ out = {"input" : out_input [:self .main_model_double ]}
108+ if len (controlnet_single ) > 0 :
109+ repeat = math .ceil (self .main_model_single / len (controlnet_single ))
110+ out_output = ()
111+ if self .latent_input :
112+ for x in controlnet_single :
113+ out_output += (x ,) * repeat
114+ else :
115+ out_output = (controlnet_single * repeat )
116+ out ["output" ] = out_output [:self .main_model_single ]
117+ return out
101118
102119 def forward (self , x , timesteps , context , y , guidance = None , hint = None , ** kwargs ):
103120 patch_size = 2
0 commit comments