Skip to content

Commit b1e27d1

Browse files
authored
Merge branch 'comfyanonymous:master' into master
2 parents 5dc405c + ec28cd9 commit b1e27d1

8 files changed

Lines changed: 57 additions & 25 deletions

File tree

.ci/windows_base_files/README_VERY_IMPORTANT.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ run_cpu.bat
1414

1515
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
1616

17-
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
17+
You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
1818

1919

2020
RECOMMENDED WAY TO UPDATE:

comfy/controlnet.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def control_merge(self, control, control_prev, output_dtype):
148148
elif self.strength_type == StrengthType.LINEAR_UP:
149149
x *= (self.strength ** float(len(control_output) - i))
150150

151-
if x.dtype != output_dtype:
151+
if output_dtype is not None and x.dtype != output_dtype:
152152
x = x.to(output_dtype)
153153

154154
out[key].append(x)
@@ -206,7 +206,6 @@ def get_control(self, x_noisy, t, cond, batched_number):
206206
if self.manual_cast_dtype is not None:
207207
dtype = self.manual_cast_dtype
208208

209-
output_dtype = x_noisy.dtype
210209
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
211210
if self.cond_hint is not None:
212211
del self.cond_hint
@@ -236,7 +235,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
236235
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
237236

238237
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
239-
return self.control_merge(control, control_prev, output_dtype)
238+
return self.control_merge(control, control_prev, output_dtype=None)
240239

241240
def copy(self):
242241
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -445,7 +444,12 @@ def load_controlnet_flux_instantx(sd):
445444
for k in sd:
446445
new_sd[k] = sd[k]
447446

448-
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
447+
num_union_modes = 0
448+
union_cnet = "controlnet_mode_embedder.weight"
449+
if union_cnet in new_sd:
450+
num_union_modes = new_sd[union_cnet].shape[0]
451+
452+
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
449453
control_model = controlnet_load_state_dict(control_model, new_sd)
450454

451455
latent_format = comfy.latent_formats.Flux()

comfy/ldm/flux/controlnet.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
class ControlNetFlux(Flux):
17-
def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
17+
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
1818
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
1919

2020
self.main_model_double = 19
@@ -23,8 +23,17 @@ def __init__(self, latent_input=False, image_model=None, dtype=None, device=None
2323
self.controlnet_blocks = nn.ModuleList([])
2424
for _ in range(self.params.depth):
2525
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
26-
# controlnet_block = zero_module(controlnet_block)
2726
self.controlnet_blocks.append(controlnet_block)
27+
28+
self.controlnet_single_blocks = nn.ModuleList([])
29+
for _ in range(self.params.depth_single_blocks):
30+
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
31+
32+
self.num_union_modes = num_union_modes
33+
self.controlnet_mode_embedder = None
34+
if self.num_union_modes > 0:
35+
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
36+
2837
self.gradient_checkpointing = False
2938
self.latent_input = latent_input
3039
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@@ -57,6 +66,7 @@ def forward_orig(
5766
timesteps: Tensor,
5867
y: Tensor,
5968
guidance: Tensor = None,
69+
control_type: Tensor = None,
6070
) -> Tensor:
6171
if img.ndim != 3 or txt.ndim != 3:
6272
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -75,29 +85,47 @@ def forward_orig(
7585
vec = vec + self.vector_in(y)
7686
txt = self.txt_in(txt)
7787

88+
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
89+
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
90+
txt = torch.cat([control_cond, txt], dim=1)
91+
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
92+
7893
ids = torch.cat((txt_ids, img_ids), dim=1)
7994
pe = self.pe_embedder(ids)
8095

81-
block_res_samples = ()
96+
controlnet_double = ()
97+
98+
for i in range(len(self.double_blocks)):
99+
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
100+
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
82101

83-
for block in self.double_blocks:
84-
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
85-
block_res_samples = block_res_samples + (img,)
102+
img = torch.cat((txt, img), 1)
86103

87-
controlnet_block_res_samples = ()
88-
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
89-
block_res_sample = controlnet_block(block_res_sample)
90-
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
104+
controlnet_single = ()
91105

106+
for i in range(len(self.single_blocks)):
107+
img = self.single_blocks[i](img, vec=vec, pe=pe)
108+
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
92109

93-
repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
110+
repeat = math.ceil(self.main_model_double / len(controlnet_double))
94111
if self.latent_input:
95112
out_input = ()
96-
for x in controlnet_block_res_samples:
113+
for x in controlnet_double:
97114
out_input += (x,) * repeat
98115
else:
99-
out_input = (controlnet_block_res_samples * repeat)
100-
return {"input": out_input[:self.main_model_double]}
116+
out_input = (controlnet_double * repeat)
117+
118+
out = {"input": out_input[:self.main_model_double]}
119+
if len(controlnet_single) > 0:
120+
repeat = math.ceil(self.main_model_single / len(controlnet_single))
121+
out_output = ()
122+
if self.latent_input:
123+
for x in controlnet_single:
124+
out_output += (x,) * repeat
125+
else:
126+
out_output = (controlnet_single * repeat)
127+
out["output"] = out_output[:self.main_model_single]
128+
return out
101129

102130
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
103131
patch_size = 2
@@ -120,4 +148,4 @@ def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
120148
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
121149

122150
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
123-
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
151+
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))

comfy/lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
540540
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
541541

542542
try:
543-
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
543+
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
544544
if dora_scale is not None:
545545
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
546546
else:

notebooks/comfyui_colab.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
8080
"\n",
8181
"# SD1.5\n",
82-
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
82+
"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
8383
"\n",
8484
"# SD2\n",
8585
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",

script_examples/basic_api_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"4": {
4444
"class_type": "CheckpointLoaderSimple",
4545
"inputs": {
46-
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
46+
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
4747
}
4848
},
4949
"5": {

script_examples/websockets_api_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def get_images(ws, prompt):
8484
"4": {
8585
"class_type": "CheckpointLoaderSimple",
8686
"inputs": {
87-
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
87+
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
8888
}
8989
},
9090
"5": {

script_examples/websockets_api_example_ws_images.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_images(ws, prompt):
8181
"4": {
8282
"class_type": "CheckpointLoaderSimple",
8383
"inputs": {
84-
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
84+
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
8585
}
8686
},
8787
"5": {

0 commit comments

Comments (0)