Skip to content

Commit 8265452

Browse files
authored
Merge branch 'comfyanonymous:master' into master
2 parents 468a99d + 66b0961 commit 8265452

21 files changed

Lines changed: 249 additions & 116 deletions

.github/workflows/windows_release_dependencies.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212
description: 'extra dependencies'
1313
required: false
1414
type: string
15-
default: "\"numpy<2\""
15+
default: ""
1616
cu:
1717
description: 'cuda version'
1818
required: true

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ To run it on services like paperspace, kaggle or colab you can use my [Jupyter N
127127

128128
## Manual Install (Windows, Linux)
129129

130+
Note that some dependencies do not yet support Python 3.13, so using Python 3.12 is recommended.
131+
130132
Git clone this repo.
131133

132134
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints

comfy/controlnet.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class StrengthType(Enum):
6060
LINEAR_UP = 2
6161

6262
class ControlBase:
63-
def __init__(self, device=None):
63+
def __init__(self):
6464
self.cond_hint_original = None
6565
self.cond_hint = None
6666
self.strength = 1.0
@@ -72,10 +72,6 @@ def __init__(self, device=None):
7272
self.compression_ratio = 8
7373
self.upscale_algorithm = 'nearest-exact'
7474
self.extra_args = {}
75-
76-
if device is None:
77-
device = comfy.model_management.get_torch_device()
78-
self.device = device
7975
self.previous_controlnet = None
8076
self.extra_conds = []
8177
self.strength_type = StrengthType.CONSTANT
@@ -185,8 +181,8 @@ def set_extra_arg(self, argument, value=None):
185181

186182

187183
class ControlNet(ControlBase):
188-
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
189-
super().__init__(device)
184+
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
185+
super().__init__()
190186
self.control_model = control_model
191187
self.load_device = load_device
192188
if control_model is not None:
@@ -242,7 +238,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
242238
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
243239
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
244240

245-
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
241+
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
246242
if x_noisy.shape[0] != self.cond_hint.shape[0]:
247243
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
248244

@@ -341,8 +337,8 @@ def forward(self, input):
341337

342338

343339
class ControlLora(ControlNet):
344-
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
345-
ControlBase.__init__(self, device)
340+
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
341+
ControlBase.__init__(self)
346342
self.control_weights = control_weights
347343
self.global_average_pooling = global_average_pooling
348344
self.extra_conds += ["y"]
@@ -662,12 +658,15 @@ def load_controlnet(ckpt_path, model=None, model_options={}):
662658

663659
class T2IAdapter(ControlBase):
664660
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
665-
super().__init__(device)
661+
super().__init__()
666662
self.t2i_model = t2i_model
667663
self.channels_in = channels_in
668664
self.control_input = None
669665
self.compression_ratio = compression_ratio
670666
self.upscale_algorithm = upscale_algorithm
667+
if device is None:
668+
device = comfy.model_management.get_torch_device()
669+
self.device = device
671670

672671
def scale_image_to(self, width, height):
673672
unshuffle_amount = self.t2i_model.unshuffle_amount

comfy/float.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
4141
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
4242
)
4343

44+
inf = torch.finfo(dtype)
45+
torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
4446
return sign
4547

4648

comfy/ldm/modules/diffusionmodules/mmdit.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import torch
77
import torch.nn as nn
8-
from .. import attention
8+
from ..attention import optimized_attention
99
from einops import rearrange, repeat
1010
from .util import timestep_embedding
1111
import comfy.ops
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
266266
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
267267
return qkv[0], qkv[1], qkv[2]
268268

269-
def optimized_attention(qkv, num_heads):
270-
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
271269

272270
class SelfAttention(nn.Module):
273271
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
326324
return x
327325

328326
def forward(self, x: torch.Tensor) -> torch.Tensor:
329-
qkv = self.pre_attention(x)
327+
q, k, v = self.pre_attention(x)
330328
x = optimized_attention(
331-
qkv, num_heads=self.num_heads
329+
q, k, v, heads=self.num_heads
332330
)
333331
x = self.post_attention(x)
334332
return x
@@ -531,8 +529,8 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
531529
assert not self.pre_only
532530
qkv, intermediates = self.pre_attention(x, c)
533531
attn = optimized_attention(
534-
qkv,
535-
num_heads=self.attn.num_heads,
532+
qkv[0], qkv[1], qkv[2],
533+
heads=self.attn.num_heads,
536534
)
537535
return self.post_attention(attn, *intermediates)
538536

@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
557555
qkv = tuple(o)
558556

559557
attn = optimized_attention(
560-
qkv,
561-
num_heads=x_block.attn.num_heads,
558+
qkv[0], qkv[1], qkv[2],
559+
heads=x_block.attn.num_heads,
562560
)
563561
context_attn, x_attn = (
564562
attn[:, : context_qkv[0].shape[1]],
@@ -642,7 +640,7 @@ def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operation
642640
def forward(self, x):
643641
qkv = self.qkv(x)
644642
q, k, v = split_qkv(qkv, self.dim_head)
645-
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
643+
x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
646644
return self.proj(x)
647645

648646
class ContextProcessorBlock(nn.Module):

comfy/lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
415415
weight *= strength_model
416416

417417
if isinstance(v, list):
418-
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
418+
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
419419

420420
if len(v) == 1:
421421
patch_type = "diff"

comfy/model_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
9696

9797
if not unet_config.get("disable_unet_model_creation", False):
9898
if model_config.custom_operations is None:
99-
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
99+
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
100+
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
100101
else:
101102
operations = model_config.custom_operations
102103
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -244,6 +245,10 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_
244245
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
245246

246247
unet_state_dict = self.diffusion_model.state_dict()
248+
249+
if self.model_config.scaled_fp8 is not None:
250+
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
251+
247252
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
248253

249254
if self.model_type == ModelType.V_PREDICTION:

comfy/model_detection.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,15 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
286286
return None
287287
model_config = model_config_from_unet_config(unet_config, state_dict)
288288
if model_config is None and use_base_if_no_match:
289-
return comfy.supported_models_base.BASE(unet_config)
290-
else:
291-
return model_config
289+
model_config = comfy.supported_models_base.BASE(unet_config)
290+
291+
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
292+
if scaled_fp8_weight is not None:
293+
model_config.scaled_fp8 = scaled_fp8_weight.dtype
294+
if model_config.scaled_fp8 == torch.float32:
295+
model_config.scaled_fp8 = torch.float8_e4m3fn
296+
297+
return model_config
292298

293299
def unet_prefix_from_state_dict(state_dict):
294300
candidates = ["model.diffusion_model.", #ldm/sgm models

comfy/model_management.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
647647
pass
648648

649649
if fp8_dtype is not None:
650+
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
651+
return fp8_dtype
652+
650653
free_model_memory = maximum_vram_for_weights(device)
651654
if model_params * 2 > free_model_memory:
652655
return fp8_dtype
@@ -840,27 +843,21 @@ def force_channels_last():
840843
#TODO
841844
return False
842845

843-
def cast_to_device(tensor, device, dtype, copy=False):
844-
device_supports_cast = False
845-
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
846-
device_supports_cast = True
847-
elif tensor.dtype == torch.bfloat16:
848-
if hasattr(device, 'type') and device.type.startswith("cuda"):
849-
device_supports_cast = True
850-
elif is_intel_xpu():
851-
device_supports_cast = True
846+
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
847+
if device is None or weight.device == device:
848+
if not copy:
849+
if dtype is None or weight.dtype == dtype:
850+
return weight
851+
return weight.to(dtype=dtype, copy=copy)
852852

853-
non_blocking = device_should_use_non_blocking(device)
853+
r = torch.empty_like(weight, dtype=dtype, device=device)
854+
r.copy_(weight, non_blocking=non_blocking)
855+
return r
856+
857+
def cast_to_device(tensor, device, dtype, copy=False):
    """Cast ``tensor`` to ``dtype`` on ``device`` by delegating to ``cast_to``.

    The ``non_blocking`` flag passed to ``cast_to`` is whatever
    ``device_supports_non_blocking(device)`` reports for the target device;
    ``copy`` is forwarded unchanged.

    Returns:
        The value returned by ``cast_to``.
    """
    return cast_to(
        tensor,
        dtype=dtype,
        device=device,
        non_blocking=device_supports_non_blocking(device),
        copy=copy,
    )
854860

855-
if device_supports_cast:
856-
if copy:
857-
if tensor.device == device:
858-
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
859-
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
860-
else:
861-
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
862-
else:
863-
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
864861

865862
def xformers_enabled():
866863
global directml_enabled

comfy/model_patcher.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,31 @@ def __call__(self, weight):
9494
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
9595

9696
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
97+
98+
def get_key_weight(model, key):
    """Fetch the weight stored under ``key`` together with its optional hooks.

    ``key`` is split on its last ``.`` into an owner path and an attribute
    name. With no ``.`` in ``key`` the weight is read straight from ``model``
    and both hooks are ``None``. Otherwise the owner object is resolved with
    ``comfy.utils.get_attr`` and probed for ``set_<attr>`` and
    ``convert_<attr>`` attributes; a missing hook is reported as ``None``.

    Returns:
        A ``(weight, set_func, convert_func)`` tuple.
    """
    owner_path, dot, attr_name = key.rpartition('.')
    if not dot:
        # Un-dotted key: there is no owner object to probe for hooks.
        return comfy.utils.get_attr(model, key), None, None

    owner = comfy.utils.get_attr(model, owner_path)

    # A missing hook attribute simply means "no hook"; lookup order
    # (setter, then converter, then the weight itself) is kept as-is.
    set_func = getattr(owner, "set_{}".format(attr_name), None)
    convert_func = getattr(owner, "convert_{}".format(attr_name), None)

    weight = getattr(owner, attr_name)
    if convert_func is not None:
        # When a converter exists the weight is re-read through the full
        # dotted path rather than reusing the direct attribute read above.
        weight = comfy.utils.get_attr(model, key)

    return weight, set_func, convert_func
121+
97122
class ModelPatcher:
98123
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
99124
self.size = size
@@ -294,14 +319,16 @@ def get_key_patches(self, filter_prefix=None):
294319
if not k.startswith(filter_prefix):
295320
continue
296321
bk = self.backup.get(k, None)
322+
weight, set_func, convert_func = get_key_weight(self.model, k)
297323
if bk is not None:
298324
weight = bk.weight
299-
else:
300-
weight = model_sd[k]
325+
if convert_func is None:
326+
convert_func = lambda a, **kwargs: a
327+
301328
if k in self.patches:
302-
p[k] = [weight] + self.patches[k]
329+
p[k] = [(weight, convert_func)] + self.patches[k]
303330
else:
304-
p[k] = (weight,)
331+
p[k] = [(weight, convert_func)]
305332
return p
306333

307334
def model_state_dict(self, filter_prefix=None):
@@ -317,8 +344,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
317344
if key not in self.patches:
318345
return
319346

320-
weight = comfy.utils.get_attr(self.model, key)
321-
347+
weight, set_func, convert_func = get_key_weight(self.model, key)
322348
inplace_update = self.weight_inplace_update or inplace_update
323349

324350
if key not in self.backup:
@@ -328,12 +354,18 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
328354
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
329355
else:
330356
temp_weight = weight.to(torch.float32, copy=True)
357+
if convert_func is not None:
358+
temp_weight = convert_func(temp_weight, inplace=True)
359+
331360
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
332-
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
333-
if inplace_update:
334-
comfy.utils.copy_to_param(self.model, key, out_weight)
361+
if set_func is None:
362+
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
363+
if inplace_update:
364+
comfy.utils.copy_to_param(self.model, key, out_weight)
365+
else:
366+
comfy.utils.set_attr_param(self.model, key, out_weight)
335367
else:
336-
comfy.utils.set_attr_param(self.model, key, out_weight)
368+
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
337369

338370
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
339371
mem_counter = 0

0 commit comments

Comments
 (0)