Skip to content

Commit b33cd61

Browse files
InstantX canny controlnet.
1 parent 34eda0f commit b33cd61

3 files changed

Lines changed: 63 additions & 27 deletions

File tree

comfy/controlnet.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
import comfy.ldm.cascade.controlnet
3535
import comfy.cldm.mmdit
3636
import comfy.ldm.hydit.controlnet
37-
import comfy.ldm.flux.controlnet_xlabs
37+
import comfy.ldm.flux.controlnet
3838

3939

4040
def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -433,12 +433,25 @@ def load_controlnet_hunyuandit(controlnet_data):
433433

434434
# Build a ControlNet wrapper around a ControlNetFlux model from the state dict `sd`.
# The model config, operations, devices and dtypes are all taken from controlnet_config(sd).
# Returns the constructed ControlNet object.
def load_controlnet_flux_xlabs(sd):
435435
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
436-
# Removed side of the diff: the model class used to be reached through comfy.ldm.flux.controlnet_xlabs.
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
436+
# Added side of the diff: identical constructor call, now reached through comfy.ldm.flux.controlnet.
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
437437
# Load the weights from `sd` into the freshly built model.
control_model = controlnet_load_state_dict(control_model, sd)
438438
# Names of the extra conditioning inputs handed to the ControlNet wrapper.
extra_conds = ['y', 'guidance']
439439
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
440440
return control
441441

442+
def load_controlnet_flux_instantx(sd):
    """Build a ControlNet wrapper from the state dict ``sd``.

    The state dict is first passed through
    ``comfy.model_detection.convert_diffusers_mmdit`` (with an empty
    prefix) and the model configuration is derived from the converted
    result.  Every entry of the original ``sd`` is then copied onto the
    converted dict before the weights are loaded into a ``ControlNetFlux``
    created with ``latent_input=True``.

    Returns the ``ControlNet`` object wrapping that model, configured with
    ``compression_ratio=1``, the Flux latent format and the ``'y'`` /
    ``'guidance'`` extra conditioning inputs.
    """
    converted_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
    (
        model_config,
        operations,
        load_device,
        unet_dtype,
        manual_cast_dtype,
        offload_device,
    ) = controlnet_config(converted_sd)

    # Overlay the original entries on top of the converted ones; on a key
    # clash the original value wins.
    for key in sd:
        converted_sd[key] = sd[key]

    flux_controlnet = comfy.ldm.flux.controlnet.ControlNetFlux(
        latent_input=True,
        operations=operations,
        device=offload_device,
        dtype=unet_dtype,
        **model_config.unet_config,
    )
    flux_controlnet = controlnet_load_state_dict(flux_controlnet, converted_sd)

    return ControlNet(
        flux_controlnet,
        compression_ratio=1,
        latent_format=comfy.latent_formats.Flux(),
        load_device=load_device,
        manual_cast_dtype=manual_cast_dtype,
        extra_conds=['y', 'guidance'],
    )
442455

443456
def load_controlnet(ckpt_path, model=None):
444457
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
@@ -504,8 +517,10 @@ def load_controlnet(ckpt_path, model=None):
504517
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
505518
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
506519
return load_controlnet_flux_xlabs(controlnet_data)
507-
else:
520+
elif "pos_embed_input.proj.weight" in controlnet_data:
508521
return load_controlnet_mmdit(controlnet_data)
522+
elif "controlnet_x_embedder.weight" in controlnet_data:
523+
return load_controlnet_flux_instantx(controlnet_data)
509524

510525
pth_key = 'control_model.zero_convs.0.0.weight'
511526
pth = False
Lines changed: 43 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
22

33
import torch
4+
import math
45
from torch import Tensor, nn
56
from einops import rearrange, repeat
67

@@ -13,34 +14,38 @@
1314

1415

1516
class ControlNetFlux(Flux):
16-
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
17+
# Constructor: initialises the Flux base class with final_layer=False, then adds the
# ControlNet-specific layers on top.
#   latent_input: stored on the instance; when true, the input_hint_block conv stack is not created
#                 (see the `if not self.latent_input` guard on the added side of the diff).
#   image_model:  accepted but not referenced anywhere in this constructor.
#   dtype / device / operations: forwarded to the base class and to every layer built here.
#   **kwargs:     forwarded unchanged to the base class constructor.
def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
1718
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
1819

20+
# main_model_double is read by forward_orig to size the list of outputs it returns.
self.main_model_double = 19
21+
# NOTE(review): main_model_single is not read in the visible code — presumably the
# count of single blocks in the main model; confirm against callers.
self.main_model_single = 38
1922
# add ControlNet blocks
2023
self.controlnet_blocks = nn.ModuleList([])
2124
# One hidden_size -> hidden_size Linear per unit of self.params.depth.
for _ in range(self.params.depth):
2225
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
2326
# controlnet_block = zero_module(controlnet_block)
2427
self.controlnet_blocks.append(controlnet_block)
25-
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
2628
self.gradient_checkpointing = False
27-
# Removed side of the diff: the hint block used to be created unconditionally.
self.input_hint_block = nn.Sequential(
28-
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
29-
nn.SiLU(),
30-
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
31-
nn.SiLU(),
32-
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
33-
nn.SiLU(),
34-
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
35-
nn.SiLU(),
36-
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
37-
nn.SiLU(),
38-
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
39-
nn.SiLU(),
40-
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
41-
nn.SiLU(),
42-
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
43-
)
29+
self.latent_input = latent_input
30+
# Linear projection applied to controlnet_cond in forward_orig before it is added to img.
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
31+
# Added side of the diff: the hint block is only built when latent_input is false;
# forward_orig applies it to controlnet_cond under the same condition.
if not self.latent_input:
32+
# Eight 3x3 convs (3 -> 16 channels, then 16 -> 16) separated by SiLU; three of them use stride=2.
self.input_hint_block = nn.Sequential(
33+
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
34+
nn.SiLU(),
35+
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
36+
nn.SiLU(),
37+
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
38+
nn.SiLU(),
39+
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
40+
nn.SiLU(),
41+
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
42+
nn.SiLU(),
43+
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
44+
nn.SiLU(),
45+
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
46+
nn.SiLU(),
47+
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
48+
)
4449

4550
def forward_orig(
4651
self,
@@ -58,8 +63,10 @@ def forward_orig(
5863

5964
# running on sequences img
6065
img = self.img_in(img)
61-
controlnet_cond = self.input_hint_block(controlnet_cond)
62-
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
66+
if not self.latent_input:
67+
controlnet_cond = self.input_hint_block(controlnet_cond)
68+
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
69+
6370
controlnet_cond = self.pos_embed_input(controlnet_cond)
6471
img = img + controlnet_cond
6572
vec = self.time_in(timestep_embedding(timesteps, 256))
@@ -82,13 +89,25 @@ def forward_orig(
8289
block_res_sample = controlnet_block(block_res_sample)
8390
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
8491

85-
return {"input": (controlnet_block_res_samples * 10)[:19]}
92+
93+
repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
94+
if self.latent_input:
95+
out_input = ()
96+
for x in controlnet_block_res_samples:
97+
out_input += (x,) * repeat
98+
else:
99+
out_input = (controlnet_block_res_samples * repeat)
100+
return {"input": out_input[:self.main_model_double]}
86101

87102
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
88-
hint = hint * 2.0 - 1.0
103+
patch_size = 2
104+
if self.latent_input:
105+
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
106+
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
107+
else:
108+
hint = hint * 2.0 - 1.0
89109

90110
bs, c, h, w = x.shape
91-
patch_size = 2
92111
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
93112

94113
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

comfy/utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
528528
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
529529
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
530530
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
531+
("pos_embed_input.bias", "controlnet_x_embedder.bias"),
532+
("pos_embed_input.weight", "controlnet_x_embedder.weight"),
531533
}
532534

533535
for k in MAP_BASIC:

0 commit comments

Comments
 (0)