Skip to content

Commit 483004d

Browse files
Support newer glora format.
1 parent 00a5d08 commit 483004d

1 file changed

Lines changed: 26 additions & 6 deletions

File tree

comfy/lora.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -528,20 +528,40 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
528528
except Exception as e:
529529
logging.error("ERROR {} {} {}".format(patch_type, key, e))
530530
elif patch_type == "glora":
531-
if v[4] is not None:
532-
alpha = v[4] / v[0].shape[0]
533-
else:
534-
alpha = 1.0
535-
536531
dora_scale = v[5]
537532

533+
old_glora = False
534+
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
535+
rank = v[0].shape[0]
536+
old_glora = True
537+
538+
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
539+
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
540+
pass
541+
else:
542+
old_glora = False
543+
rank = v[1].shape[0]
544+
538545
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
539546
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
540547
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
541548
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
542549

550+
if v[4] is not None:
551+
alpha = v[4] / rank
552+
else:
553+
alpha = 1.0
554+
543555
try:
544-
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
556+
if old_glora:
557+
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
558+
else:
559+
if weight.dim() > 2:
560+
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
561+
else:
562+
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
563+
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
564+
545565
if dora_scale is not None:
546566
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
547567
else:

0 commit comments

Comments
 (0)