Skip to content

Commit 10a79e9

Browse files
Implement model part of flux union controlnet.
1 parent ea3f39b commit 10a79e9

2 files changed

Lines changed: 19 additions & 3 deletions

File tree

comfy/controlnet.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,12 @@ def load_controlnet_flux_instantx(sd):
444444
for k in sd:
445445
new_sd[k] = sd[k]
446446

447-
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
447+
num_union_modes = 0
448+
union_cnet = "controlnet_mode_embedder.weight"
449+
if union_cnet in new_sd:
450+
num_union_modes = new_sd[union_cnet].shape[0]
451+
452+
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
448453
control_model = controlnet_load_state_dict(control_model, new_sd)
449454

450455
latent_format = comfy.latent_formats.Flux()

comfy/ldm/flux/controlnet.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
class ControlNetFlux(Flux):
17-
def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
17+
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
1818
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
1919

2020
self.main_model_double = 19
@@ -29,6 +29,11 @@ def __init__(self, latent_input=False, image_model=None, dtype=None, device=None
2929
for _ in range(self.params.depth_single_blocks):
3030
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
3131

32+
self.num_union_modes = num_union_modes
33+
self.controlnet_mode_embedder = None
34+
if self.num_union_modes > 0:
35+
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
36+
3237
self.gradient_checkpointing = False
3338
self.latent_input = latent_input
3439
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@@ -61,6 +66,7 @@ def forward_orig(
6166
timesteps: Tensor,
6267
y: Tensor,
6368
guidance: Tensor = None,
69+
control_type: Tensor = None,
6470
) -> Tensor:
6571
if img.ndim != 3 or txt.ndim != 3:
6672
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -79,6 +85,11 @@ def forward_orig(
7985
vec = vec + self.vector_in(y)
8086
txt = self.txt_in(txt)
8187

88+
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
89+
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
90+
txt = torch.cat([control_cond, txt], dim=1)
91+
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
92+
8293
ids = torch.cat((txt_ids, img_ids), dim=1)
8394
pe = self.pe_embedder(ids)
8495

@@ -137,4 +148,4 @@ def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
137148
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
138149

139150
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
140-
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
151+
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))

0 commit comments

Comments
 (0)