@@ -324,6 +324,7 @@ def model_lora_keys_unet(model, key_map={}):
                 to = diffusers_keys[k]
                 key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
                 key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
+                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
 
     return key_map
 
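Note on the added mapping: OneTrainer's Flux LoRA files prefix keys with lora_transformer_ and replace dots in the module path with underscores, so the new entry routes those keys to the same target as the two existing formats. A minimal sketch of how one key fans out (the module name is hypothetical, not taken from the diff):

    k = "single_transformer_blocks.0.attn.to_q.weight"    # hypothetical diffusers-style key
    base = k[:-len(".weight")]                            # "single_transformer_blocks.0.attn.to_q"
    "transformer.{}".format(base)                         # simpletrainer / diffusers flux lora
    "lycoris_{}".format(base.replace(".", "_"))           # simpletrainer lycoris
    "lora_transformer_{}".format(base.replace(".", "_"))  # onetrainer (added here)

All three key styles resolve to the same target, so a lora trained in any of these tools patches the same model weight.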
@@ -527,20 +528,40 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             except Exception as e:
                 logging.error("ERROR {} {} {}".format(patch_type, key, e))
         elif patch_type == "glora":
-            if v[4] is not None:
-                alpha = v[4] / v[0].shape[0]
-            else:
-                alpha = 1.0
-
             dora_scale = v[5]
 
+            old_glora = False
+            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
+                rank = v[0].shape[0]
+                old_glora = True
+
+            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
+                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
+                    pass
+                else:
+                    old_glora = False
+                    rank = v[1].shape[0]
+
             a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
             a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
             b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
             b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
 
+            if v[4] is not None:
+                alpha = v[4] / rank
+            else:
+                alpha = 1.0
+
             try:
-                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
+                if old_glora:
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                else:
+                    if weight.dim() > 2:
+                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+                    else:
+                        lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)
+
                 if dora_scale is not None:
                     weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                 else:
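Why the alpha computation moved: GLoRA tensors can now arrive in two layouts, and alpha = v[4] / rank needs the rank of whichever layout the shape checks matched, so it is computed after them. For the old LyCORIS layout the diff is b2 @ b1 + (W @ a2) @ a1; for the newer layout it is (W @ a1) @ a2 + b1 @ b2, with an einsum path for conv weights of more than two dims. When both checks pass on a square weight, the old interpretation wins. A minimal shape sanity check with made-up dimensions (illustration only, not ComfyUI code):

    import torch

    out_f, in_f, r = 8, 6, 2          # hypothetical layer dims and rank
    W = torch.randn(out_f, in_f)

    # old lycoris glora layout: a1 (r, in), a2 (in, r), b1 (r, in), b2 (out, r)
    a1, a2 = torch.randn(r, in_f), torch.randn(in_f, r)
    b1, b2 = torch.randn(r, in_f), torch.randn(out_f, r)
    # matches the first check: b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]
    old_diff = torch.mm(b2, b1) + torch.mm(torch.mm(W, a2), a1)
    assert old_diff.shape == W.shape

    # newer glora layout: a1 (in, r), a2 (r, in), b1 (out, r), b2 (r, in)
    a1, a2 = torch.randn(in_f, r), torch.randn(r, in_f)
    b1, b2 = torch.randn(out_f, r), torch.randn(r, in_f)
    # matches the second check: b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]
    new_diff = torch.mm(torch.mm(W, a1), a2) + torch.mm(b1, b2)
    assert new_diff.shape == W.shape

Since out_f != in_f here, only one check passes per layout; an ambiguous square weight falls back to the old-glora branch above.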