@@ -528,20 +528,40 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
528528 except Exception as e :
529529 logging .error ("ERROR {} {} {}" .format (patch_type , key , e ))
530530 elif patch_type == "glora" :
531- if v [4 ] is not None :
532- alpha = v [4 ] / v [0 ].shape [0 ]
533- else :
534- alpha = 1.0
535-
536531 dora_scale = v [5 ]
537532
533+ old_glora = False
534+ if v [3 ].shape [1 ] == v [2 ].shape [0 ] == v [0 ].shape [0 ] == v [1 ].shape [1 ]:
535+ rank = v [0 ].shape [0 ]
536+ old_glora = True
537+
538+ if v [3 ].shape [0 ] == v [2 ].shape [1 ] == v [0 ].shape [1 ] == v [1 ].shape [0 ]:
539+ if old_glora and v [1 ].shape [0 ] == weight .shape [0 ] and weight .shape [0 ] == weight .shape [1 ]:
540+ pass
541+ else :
542+ old_glora = False
543+ rank = v [1 ].shape [0 ]
544+
538545 a1 = comfy .model_management .cast_to_device (v [0 ].flatten (start_dim = 1 ), weight .device , intermediate_dtype )
539546 a2 = comfy .model_management .cast_to_device (v [1 ].flatten (start_dim = 1 ), weight .device , intermediate_dtype )
540547 b1 = comfy .model_management .cast_to_device (v [2 ].flatten (start_dim = 1 ), weight .device , intermediate_dtype )
541548 b2 = comfy .model_management .cast_to_device (v [3 ].flatten (start_dim = 1 ), weight .device , intermediate_dtype )
542549
550+ if v [4 ] is not None :
551+ alpha = v [4 ] / rank
552+ else :
553+ alpha = 1.0
554+
543555 try :
544- lora_diff = (torch .mm (b2 , b1 ) + torch .mm (torch .mm (weight .flatten (start_dim = 1 ).to (dtype = intermediate_dtype ), a2 ), a1 )).reshape (weight .shape )
556+ if old_glora :
557+ lora_diff = (torch .mm (b2 , b1 ) + torch .mm (torch .mm (weight .flatten (start_dim = 1 ).to (dtype = intermediate_dtype ), a2 ), a1 )).reshape (weight .shape ) #old lycoris glora
558+ else :
559+ if weight .dim () > 2 :
560+ lora_diff = torch .einsum ("o i ..., i j -> o j ..." , torch .einsum ("o i ..., i j -> o j ..." , weight .to (dtype = intermediate_dtype ), a1 ), a2 ).reshape (weight .shape )
561+ else :
562+ lora_diff = torch .mm (torch .mm (weight .to (dtype = intermediate_dtype ), a1 ), a2 ).reshape (weight .shape )
563+ lora_diff += torch .mm (b1 , b2 ).reshape (weight .shape )
564+
545565 if dora_scale is not None :
546566 weight = function (weight_decompose (dora_scale , weight , lora_diff , alpha , strength , intermediate_dtype ))
547567 else :
0 commit comments