Skip to content

Commit 656759a

Browse files
authored
Merge branch 'comfyanonymous:master' into master
2 parents 12b10bf + 9230f65 commit 656759a

11 files changed

Lines changed: 295 additions & 232 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
230230

231231
Use ```--preview-method auto``` to enable previews.
232232

233-
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
233+
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
234234

235235
## How to use TLS/SSL?
236236
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`

comfy/controlnet.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,8 @@ def controlnet_config(sd):
391391
else:
392392
operations = comfy.ops.disable_weight_init
393393

394-
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
394+
offload_device = comfy.model_management.unet_offload_device()
395+
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
395396

396397
def controlnet_load_state_dict(control_model, sd):
397398
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
405406

406407
def load_controlnet_mmdit(sd):
407408
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
408-
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
409+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
409410
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
410411
for k in sd:
411412
new_sd[k] = sd[k]
412413

413-
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
414+
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
414415
control_model = controlnet_load_state_dict(control_model, new_sd)
415416

416417
latent_format = comfy.latent_formats.SD3()
@@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
420421

421422

422423
def load_controlnet_hunyuandit(controlnet_data):
423-
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
424+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
424425

425-
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
426+
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
426427
control_model = controlnet_load_state_dict(control_model, controlnet_data)
427428

428429
latent_format = comfy.latent_formats.SDXL()
@@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
431432
return control
432433

433434
def load_controlnet_flux_xlabs(sd):
434-
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
435-
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
435+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
436+
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
436437
control_model = controlnet_load_state_dict(control_model, sd)
437438
extra_conds = ['y', 'guidance']
438439
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None):
536537
if manual_cast_dtype is not None:
537538
controlnet_config["operations"] = comfy.ops.manual_cast
538539
controlnet_config["dtype"] = unet_dtype
540+
controlnet_config["device"] = comfy.model_management.unet_offload_device()
539541
controlnet_config.pop("out_channels")
540542
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
541543
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

comfy/lora.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
"""
1818

1919
import comfy.utils
20+
import comfy.model_management
21+
import comfy.model_base
2022
import logging
23+
import torch
2124

2225
LORA_CLIP_MAP = {
2326
"mlp.fc1": "mlp_fc1",
@@ -322,3 +325,192 @@ def model_lora_keys_unet(model, key_map={}):
322325
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
323326

324327
return key_map
328+
329+
330+
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
331+
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
332+
lora_diff *= alpha
333+
weight_calc = weight + lora_diff.type(weight.dtype)
334+
weight_norm = (
335+
weight_calc.transpose(0, 1)
336+
.reshape(weight_calc.shape[1], -1)
337+
.norm(dim=1, keepdim=True)
338+
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
339+
.transpose(0, 1)
340+
)
341+
342+
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
343+
if strength != 1.0:
344+
weight_calc -= weight
345+
weight += strength * (weight_calc)
346+
else:
347+
weight[:] = weight_calc
348+
return weight
349+
350+
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
351+
for p in patches:
352+
strength = p[0]
353+
v = p[1]
354+
strength_model = p[2]
355+
offset = p[3]
356+
function = p[4]
357+
if function is None:
358+
function = lambda a: a
359+
360+
old_weight = None
361+
if offset is not None:
362+
old_weight = weight
363+
weight = weight.narrow(offset[0], offset[1], offset[2])
364+
365+
if strength_model != 1.0:
366+
weight *= strength_model
367+
368+
if isinstance(v, list):
369+
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
370+
371+
if len(v) == 1:
372+
patch_type = "diff"
373+
elif len(v) == 2:
374+
patch_type = v[0]
375+
v = v[1]
376+
377+
if patch_type == "diff":
378+
w1 = v[0]
379+
if strength != 0.0:
380+
if w1.shape != weight.shape:
381+
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
382+
else:
383+
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
384+
elif patch_type == "lora": #lora/locon
385+
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
386+
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
387+
dora_scale = v[4]
388+
if v[2] is not None:
389+
alpha = v[2] / mat2.shape[0]
390+
else:
391+
alpha = 1.0
392+
393+
if v[3] is not None:
394+
#locon mid weights, hopefully the math is fine because I didn't properly test it
395+
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
396+
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
397+
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
398+
try:
399+
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
400+
if dora_scale is not None:
401+
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
402+
else:
403+
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
404+
except Exception as e:
405+
logging.error("ERROR {} {} {}".format(patch_type, key, e))
406+
elif patch_type == "lokr":
407+
w1 = v[0]
408+
w2 = v[1]
409+
w1_a = v[3]
410+
w1_b = v[4]
411+
w2_a = v[5]
412+
w2_b = v[6]
413+
t2 = v[7]
414+
dora_scale = v[8]
415+
dim = None
416+
417+
if w1 is None:
418+
dim = w1_b.shape[0]
419+
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
420+
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
421+
else:
422+
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
423+
424+
if w2 is None:
425+
dim = w2_b.shape[0]
426+
if t2 is None:
427+
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
428+
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
429+
else:
430+
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
431+
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
432+
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
433+
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
434+
else:
435+
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
436+
437+
if len(w2.shape) == 4:
438+
w1 = w1.unsqueeze(2).unsqueeze(2)
439+
if v[2] is not None and dim is not None:
440+
alpha = v[2] / dim
441+
else:
442+
alpha = 1.0
443+
444+
try:
445+
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
446+
if dora_scale is not None:
447+
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
448+
else:
449+
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
450+
except Exception as e:
451+
logging.error("ERROR {} {} {}".format(patch_type, key, e))
452+
elif patch_type == "loha":
453+
w1a = v[0]
454+
w1b = v[1]
455+
if v[2] is not None:
456+
alpha = v[2] / w1b.shape[0]
457+
else:
458+
alpha = 1.0
459+
460+
w2a = v[3]
461+
w2b = v[4]
462+
dora_scale = v[7]
463+
if v[5] is not None: #cp decomposition
464+
t1 = v[5]
465+
t2 = v[6]
466+
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
467+
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
468+
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
469+
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
470+
471+
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
472+
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
473+
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
474+
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
475+
else:
476+
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
477+
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
478+
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
479+
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
480+
481+
try:
482+
lora_diff = (m1 * m2).reshape(weight.shape)
483+
if dora_scale is not None:
484+
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
485+
else:
486+
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
487+
except Exception as e:
488+
logging.error("ERROR {} {} {}".format(patch_type, key, e))
489+
elif patch_type == "glora":
490+
if v[4] is not None:
491+
alpha = v[4] / v[0].shape[0]
492+
else:
493+
alpha = 1.0
494+
495+
dora_scale = v[5]
496+
497+
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
498+
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
499+
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
500+
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
501+
502+
try:
503+
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
504+
if dora_scale is not None:
505+
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
506+
else:
507+
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
508+
except Exception as e:
509+
logging.error("ERROR {} {} {}".format(patch_type, key, e))
510+
else:
511+
logging.warning("patch type not recognized {} {}".format(patch_type, key))
512+
513+
if old_weight is not None:
514+
weight = old_weight
515+
516+
return weight

comfy/model_detection.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
472472
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
473473
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
474474
'use_temporal_attention': False, 'use_temporal_resblock': False}
475+
476+
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
477+
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
478+
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
479+
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
480+
'use_temporal_attention': False, 'use_temporal_resblock': False}
475481

476482

477-
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
483+
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
478484

479485
for unet_config in supported_models:
480486
matches = True

comfy/model_management.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,14 @@ class CPUState(Enum):
4444

4545
total_vram = 0
4646

47-
lowvram_available = True
4847
xpu_available = False
48+
try:
49+
torch_version = torch.version.__version__
50+
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
51+
except:
52+
pass
4953

54+
lowvram_available = True
5055
if args.deterministic:
5156
logging.info("Using deterministic algorithms for pytorch")
5257
torch.use_deterministic_algorithms(True, warn_only=True)
@@ -66,10 +71,10 @@ class CPUState(Enum):
6671

6772
try:
6873
import intel_extension_for_pytorch as ipex
69-
if torch.xpu.is_available():
70-
xpu_available = True
74+
_ = torch.xpu.device_count()
75+
xpu_available = torch.xpu.is_available()
7176
except:
72-
pass
77+
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
7378

7479
try:
7580
if torch.backends.mps.is_available():
@@ -189,7 +194,6 @@ def is_nvidia():
189194

190195
try:
191196
if is_nvidia():
192-
torch_version = torch.version.__version__
193197
if int(torch_version[0]) >= 2:
194198
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
195199
ENABLE_PYTORCH_ATTENTION = True
@@ -321,8 +325,9 @@ def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
321325
self.model_unload()
322326
raise e
323327

324-
if is_intel_xpu() and not args.disable_ipex_optimize:
325-
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
328+
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
329+
with torch.no_grad():
330+
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
326331

327332
self.weights_loaded = True
328333
return self.real_model
@@ -561,7 +566,9 @@ def loaded_models(only_currently_used=False):
561566
def cleanup_models(keep_clone_weights_loaded=False):
562567
to_delete = []
563568
for i in range(len(current_loaded_models)):
564-
if sys.getrefcount(current_loaded_models[i].model) <= 2:
569+
#TODO: very fragile function needs improvement
570+
num_refs = sys.getrefcount(current_loaded_models[i].model)
571+
if num_refs <= 2:
565572
if not keep_clone_weights_loaded:
566573
to_delete = [i] + to_delete
567574
#TODO: find a less fragile way to do this.
@@ -884,7 +891,8 @@ def pytorch_attention_flash_attention():
884891
def force_upcast_attention_dtype():
885892
upcast = args.force_upcast_attention
886893
try:
887-
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
894+
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
895+
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
888896
upcast = True
889897
except:
890898
pass

0 commit comments

Comments (0)