Skip to content

Commit ea3f39b

Browse files
InstantX depth flux controlnet.
1 parent b33cd61 commit ea3f39b

2 files changed

Lines changed: 32 additions & 16 deletions

File tree

comfy/controlnet.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def control_merge(self, control, control_prev, output_dtype):
148148
elif self.strength_type == StrengthType.LINEAR_UP:
149149
x *= (self.strength ** float(len(control_output) - i))
150150

151-
if x.dtype != output_dtype:
151+
if output_dtype is not None and x.dtype != output_dtype:
152152
x = x.to(output_dtype)
153153

154154
out[key].append(x)
@@ -206,7 +206,6 @@ def get_control(self, x_noisy, t, cond, batched_number):
206206
if self.manual_cast_dtype is not None:
207207
dtype = self.manual_cast_dtype
208208

209-
output_dtype = x_noisy.dtype
210209
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
211210
if self.cond_hint is not None:
212211
del self.cond_hint
@@ -236,7 +235,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
236235
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
237236

238237
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
239-
return self.control_merge(control, control_prev, output_dtype)
238+
return self.control_merge(control, control_prev, output_dtype=None)
240239

241240
def copy(self):
242241
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)

comfy/ldm/flux/controlnet.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ def __init__(self, latent_input=False, image_model=None, dtype=None, device=None
2323
self.controlnet_blocks = nn.ModuleList([])
2424
for _ in range(self.params.depth):
2525
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
26-
# controlnet_block = zero_module(controlnet_block)
2726
self.controlnet_blocks.append(controlnet_block)
27+
28+
self.controlnet_single_blocks = nn.ModuleList([])
29+
for _ in range(self.params.depth_single_blocks):
30+
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
31+
2832
self.gradient_checkpointing = False
2933
self.latent_input = latent_input
3034
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@@ -78,26 +82,39 @@ def forward_orig(
7882
ids = torch.cat((txt_ids, img_ids), dim=1)
7983
pe = self.pe_embedder(ids)
8084

81-
block_res_samples = ()
85+
controlnet_double = ()
86+
87+
for i in range(len(self.double_blocks)):
88+
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
89+
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
8290

83-
for block in self.double_blocks:
84-
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
85-
block_res_samples = block_res_samples + (img,)
91+
img = torch.cat((txt, img), 1)
8692

87-
controlnet_block_res_samples = ()
88-
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
89-
block_res_sample = controlnet_block(block_res_sample)
90-
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
93+
controlnet_single = ()
9194

95+
for i in range(len(self.single_blocks)):
96+
img = self.single_blocks[i](img, vec=vec, pe=pe)
97+
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
9298

93-
repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
99+
repeat = math.ceil(self.main_model_double / len(controlnet_double))
94100
if self.latent_input:
95101
out_input = ()
96-
for x in controlnet_block_res_samples:
102+
for x in controlnet_double:
97103
out_input += (x,) * repeat
98104
else:
99-
out_input = (controlnet_block_res_samples * repeat)
100-
return {"input": out_input[:self.main_model_double]}
105+
out_input = (controlnet_double * repeat)
106+
107+
out = {"input": out_input[:self.main_model_double]}
108+
if len(controlnet_single) > 0:
109+
repeat = math.ceil(self.main_model_single / len(controlnet_single))
110+
out_output = ()
111+
if self.latent_input:
112+
for x in controlnet_single:
113+
out_output += (x,) * repeat
114+
else:
115+
out_output = (controlnet_single * repeat)
116+
out["output"] = out_output[:self.main_model_single]
117+
return out
101118

102119
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
103120
patch_size = 2

0 commit comments

Comments
 (0)