
Commit 5dc405c

Merge branch 'comfyanonymous:master' into master
2 parents 49a4357 + b33cd61 commit 5dc405c

10 files changed

Lines changed: 107 additions & 81 deletions


README.md

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 <div align="center">
 
 # ComfyUI
-**The most powerful and modular stable diffusion GUI and backend.**
+**The most powerful and modular diffusion model GUI and backend.**
 
 
 [![Website][website-shield]][website-url]

comfy/controlnet.py

Lines changed: 18 additions & 3 deletions
@@ -34,7 +34,7 @@
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
-import comfy.ldm.flux.controlnet_xlabs
+import comfy.ldm.flux.controlnet
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -433,12 +433,25 @@ def load_controlnet_hunyuandit(controlnet_data):
 
 def load_controlnet_flux_xlabs(sd):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
-    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
+def load_controlnet_flux_instantx(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
+
+    latent_format = comfy.latent_formats.Flux()
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
 
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
@@ -504,8 +517,10 @@ def load_controlnet(ckpt_path, model=None):
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
             return load_controlnet_flux_xlabs(controlnet_data)
-        else:
+        elif "pos_embed_input.proj.weight" in controlnet_data:
             return load_controlnet_mmdit(controlnet_data)
+        elif "controlnet_x_embedder.weight" in controlnet_data:
+            return load_controlnet_flux_instantx(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
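
With this change, load_controlnet still picks the loader by sniffing the keys present in the checkpoint's state dict; an InstantX-style Flux ControlNet saved in diffusers format is detected via "controlnet_x_embedder.weight" and routed to the new load_controlnet_flux_instantx path. A minimal usage sketch, assuming the ComfyUI repo is importable and using a hypothetical checkpoint path:

import comfy.controlnet

# Hypothetical file name, for illustration only; any InstantX Flux ControlNet
# checkpoint in diffusers format should take the same detection path and
# come back as a comfy.controlnet.ControlNet instance.
control = comfy.controlnet.load_controlnet("models/controlnet/flux_instantx_canny.safetensors")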

comfy/ldm/common_dit.py

Lines changed: 13 additions & 0 deletions
@@ -1,8 +1,21 @@
 import torch
+import comfy.ops
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
         padding_mode = "reflect"
     pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
     pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
     return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+except:
+    rms_norm_torch = None
+
+def rms_norm(x, weight, eps=1e-6):
+    if rms_norm_torch is not None:
+        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    else:
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
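
The new shared rms_norm helper uses the fused torch.nn.functional.rms_norm when the installed PyTorch exposes it, and otherwise falls back to the explicit formula. A self-contained sanity check that the two paths agree, assuming a PyTorch build (2.4+) that provides torch.nn.functional.rms_norm:

import torch

x = torch.randn(2, 16, 64)
weight = torch.rand(64)
eps = 1e-6

# Fused path, used when torch.nn.functional.rms_norm exists.
fused = torch.nn.functional.rms_norm(x, weight.shape, weight=weight, eps=eps)

# Manual fallback path, as written in the else branch above.
rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
manual = (x * rrms) * weight

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True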

comfy/ldm/flux/{controlnet_xlabs.py → controlnet.py}

Lines changed: 43 additions & 24 deletions
@@ -1,6 +1,7 @@
 #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
 
 import torch
+import math
 from torch import Tensor, nn
 from einops import rearrange, repeat
 
@@ -13,34 +14,38 @@
 
 
 class ControlNetFlux(Flux):
-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
         super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
 
+        self.main_model_double = 19
+        self.main_model_single = 38
         # add ControlNet blocks
         self.controlnet_blocks = nn.ModuleList([])
         for _ in range(self.params.depth):
             controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
             # controlnet_block = zero_module(controlnet_block)
             self.controlnet_blocks.append(controlnet_block)
-        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False
-        self.input_hint_block = nn.Sequential(
-            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
-        )
+        self.latent_input = latent_input
+        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        if not self.latent_input:
+            self.input_hint_block = nn.Sequential(
+                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+            )
 
     def forward_orig(
         self,
@@ -58,8 +63,10 @@ def forward_orig(
 
         # running on sequences img
         img = self.img_in(img)
-        controlnet_cond = self.input_hint_block(controlnet_cond)
-        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if not self.latent_input:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+
        controlnet_cond = self.pos_embed_input(controlnet_cond)
         img = img + controlnet_cond
         vec = self.time_in(timestep_embedding(timesteps, 256))
@@ -82,13 +89,25 @@
             block_res_sample = controlnet_block(block_res_sample)
             controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
 
-        return {"input": (controlnet_block_res_samples * 10)[:19]}
+
+        repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
+        if self.latent_input:
+            out_input = ()
+            for x in controlnet_block_res_samples:
+                out_input += (x,) * repeat
+        else:
+            out_input = (controlnet_block_res_samples * repeat)
+        return {"input": out_input[:self.main_model_double]}
 
     def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
-        hint = hint * 2.0 - 1.0
+        patch_size = 2
+        if self.latent_input:
+            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
+            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        else:
+            hint = hint * 2.0 - 1.0
 
         bs, c, h, w = x.shape
-        patch_size = 2
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
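
Two behavioral notes on ControlNetFlux with latent_input=True: the hint is treated as an already-encoded latent (padded to the patch size and patchified instead of being run through the conv input_hint_block), and the per-block residuals are tiled to cover the base model's 19 double blocks. A toy sketch of that tiling, using stand-in strings instead of tensors and an arbitrary block count:

import math

main_model_double = 19
controlnet_block_res_samples = ("b0", "b1", "b2")  # stand-ins for per-block tensors

repeat = math.ceil(main_model_double / len(controlnet_block_res_samples))  # 7

# latent_input=True: each sample is repeated back-to-back (b0 b0 ... b1 b1 ...).
out_latent = ()
for x in controlnet_block_res_samples:
    out_latent += (x,) * repeat

# latent_input=False: the whole tuple is cycled (b0 b1 b2 b0 b1 b2 ...).
out_image = controlnet_block_res_samples * repeat

print(out_latent[:main_model_double])
print(out_image[:main_model_double])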

comfy/ldm/flux/layers.py

Lines changed: 2 additions & 2 deletions
@@ -6,6 +6,7 @@
 
 from .math import attention, rope
 import comfy.ops
+import comfy.ldm.common_dit
 
 
 class EmbedND(nn.Module):
@@ -63,8 +64,7 @@ def __init__(self, dim: int, dtype=None, device=None, operations=None):
         self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
 
     def forward(self, x: Tensor):
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
+        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
 
 
 class QKNorm(torch.nn.Module):

comfy/ldm/modules/diffusionmodules/mmdit.py

Lines changed: 2 additions & 22 deletions
@@ -355,29 +355,9 @@ def __init__(
         else:
             self.register_parameter("weight", None)
 
-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-        """
-        x = self._norm(x)
-        if self.learnable_scale:
-            return x * self.weight.to(device=x.device, dtype=x.dtype)
-        else:
-            return x
+        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
+
 
 
 class SwiGLUFeedForward(nn.Module):

comfy/utils.py

Lines changed: 2 additions & 0 deletions
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
         ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
         ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
+        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
+        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
     }
 
     for k in MAP_BASIC:
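
The two new MAP_BASIC entries are what let convert_diffusers_mmdit rename the diffusers-style InstantX keys to the names ControlNetFlux expects, which load_controlnet_flux_instantx relies on. A toy sketch of that renaming, with placeholder values and only the keys taken from the mapping above:

# Hypothetical diffusers-style state dict, for illustration only.
diffusers_sd = {"controlnet_x_embedder.weight": "W", "controlnet_x_embedder.bias": "b"}

renames = {
    "controlnet_x_embedder.weight": "pos_embed_input.weight",
    "controlnet_x_embedder.bias": "pos_embed_input.bias",
}

comfy_sd = {renames.get(k, k): v for k, v in diffusers_sd.items()}
print(sorted(comfy_sd))  # ['pos_embed_input.bias', 'pos_embed_input.weight']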

script_examples/websockets_api_example.py

Lines changed: 8 additions & 9 deletions
@@ -41,15 +41,14 @@ def get_images(ws, prompt):
             continue #previews are binary data
 
     history = get_history(prompt_id)[prompt_id]
-    for o in history['outputs']:
-        for node_id in history['outputs']:
-            node_output = history['outputs'][node_id]
-            if 'images' in node_output:
-                images_output = []
-                for image in node_output['images']:
-                    image_data = get_image(image['filename'], image['subfolder'], image['type'])
-                    images_output.append(image_data)
-                output_images[node_id] = images_output
+    for node_id in history['outputs']:
+        node_output = history['outputs'][node_id]
+        images_output = []
+        if 'images' in node_output:
+            for image in node_output['images']:
+                image_data = get_image(image['filename'], image['subfolder'], image['type'])
+                images_output.append(image_data)
+        output_images[node_id] = images_output
 
     return output_images
 
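
The rewritten loop drops the redundant outer `for o in history['outputs']` pass and now records an entry for every output node, storing an empty list when a node produced no images instead of skipping it. A self-contained sketch of the new behavior, using a hypothetical history payload:

# Hypothetical /history payload, for illustration only.
history_outputs = {
    "9": {"images": [{"filename": "ComfyUI_00001_.png", "subfolder": "", "type": "output"}]},
    "12": {"text": ["no images here"]},
}

output_images = {}
for node_id, node_output in history_outputs.items():
    images_output = []
    if 'images' in node_output:
        for image in node_output['images']:
            images_output.append(image['filename'])  # stand-in for the downloaded bytes
    output_images[node_id] = images_output

print(output_images)  # {'9': ['ComfyUI_00001_.png'], '12': []}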

tests/inference/test_execution.py

Lines changed: 10 additions & 11 deletions
@@ -95,17 +95,16 @@ def run(self, graph):
                 pass # Probably want to store this off for testing
 
         history = self.get_history(prompt_id)[prompt_id]
-        for o in history['outputs']:
-            for node_id in history['outputs']:
-                node_output = history['outputs'][node_id]
-                result.outputs[node_id] = node_output
-                if 'images' in node_output:
-                    images_output = []
-                    for image in node_output['images']:
-                        image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
-                        image_obj = Image.open(BytesIO(image_data))
-                        images_output.append(image_obj)
-                    node_output['image_objects'] = images_output
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            result.outputs[node_id] = node_output
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    image_obj = Image.open(BytesIO(image_data))
+                    images_output.append(image_obj)
+            node_output['image_objects'] = images_output
 
         return result
 

tests/inference/test_inference.py

Lines changed: 8 additions & 9 deletions
@@ -109,15 +109,14 @@ def get_images(self, graph, save=True):
                 continue #previews are binary data
 
         history = self.get_history(prompt_id)[prompt_id]
-        for o in history['outputs']:
-            for node_id in history['outputs']:
-                node_output = history['outputs'][node_id]
-                if 'images' in node_output:
-                    images_output = []
-                    for image in node_output['images']:
-                        image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
-                        images_output.append(image_data)
-                    output_images[node_id] = images_output
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    images_output.append(image_data)
+            output_images[node_id] = images_output
 
         return output_images
 
