@@ -94,6 +94,31 @@ def __call__(self, weight):
9494 return comfy .float .stochastic_rounding (comfy .lora .calculate_weight (self .patches [self .key ], weight .to (intermediate_dtype ), self .key , intermediate_dtype = intermediate_dtype ), weight .dtype , seed = string_to_seed (self .key ))
9595
9696 return comfy .lora .calculate_weight (self .patches [self .key ], weight , self .key , intermediate_dtype = intermediate_dtype )
97+
def get_key_weight(model, key):
    """Resolve ``key`` on ``model`` and return ``(weight, set_func, convert_func)``.

    For a dotted key the owning op object is fetched first so that optional
    per-attribute hooks can be discovered on it:

    - ``set_<attr>``     -> returned as ``set_func`` (``None`` if absent)
    - ``convert_<attr>`` -> returned as ``convert_func`` (``None`` if absent)

    When a convert hook exists, the weight is re-read through
    ``comfy.utils.get_attr(model, key)`` rather than taken straight off the
    op — presumably so the canonical (converted) storage is returned; see
    the callers in ``ModelPatcher`` for how the hooks are consumed.
    """
    set_func = None
    convert_func = None
    parts = key.rsplit('.', 1)

    if len(parts) < 2:
        # Top-level key: no parent op to probe for hooks.
        weight = comfy.utils.get_attr(model, key)
        return weight, set_func, convert_func

    parent_path, attr = parts
    op = comfy.utils.get_attr(model, parent_path)

    # getattr with a default suppresses exactly AttributeError, which is
    # what the explicit try/except pattern would catch.
    set_func = getattr(op, "set_{}".format(attr), None)
    convert_func = getattr(op, "convert_{}".format(attr), None)

    weight = getattr(op, attr)
    if convert_func is not None:
        weight = comfy.utils.get_attr(model, key)

    return weight, set_func, convert_func
121+
97122class ModelPatcher :
98123 def __init__ (self , model , load_device , offload_device , size = 0 , weight_inplace_update = False ):
99124 self .size = size
@@ -294,14 +319,16 @@ def get_key_patches(self, filter_prefix=None):
294319 if not k .startswith (filter_prefix ):
295320 continue
296321 bk = self .backup .get (k , None )
322+ weight , set_func , convert_func = get_key_weight (self .model , k )
297323 if bk is not None :
298324 weight = bk .weight
299- else :
300- weight = model_sd [k ]
325+ if convert_func is None :
326+ convert_func = lambda a , ** kwargs : a
327+
301328 if k in self .patches :
302- p [k ] = [weight ] + self .patches [k ]
329+ p [k ] = [( weight , convert_func ) ] + self .patches [k ]
303330 else :
304- p [k ] = (weight ,)
331+ p [k ] = [ (weight , convert_func )]
305332 return p
306333
307334 def model_state_dict (self , filter_prefix = None ):
@@ -317,8 +344,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
317344 if key not in self .patches :
318345 return
319346
320- weight = comfy .utils .get_attr (self .model , key )
321-
347+ weight , set_func , convert_func = get_key_weight (self .model , key )
322348 inplace_update = self .weight_inplace_update or inplace_update
323349
324350 if key not in self .backup :
@@ -328,12 +354,18 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
328354 temp_weight = comfy .model_management .cast_to_device (weight , device_to , torch .float32 , copy = True )
329355 else :
330356 temp_weight = weight .to (torch .float32 , copy = True )
357+ if convert_func is not None :
358+ temp_weight = convert_func (temp_weight , inplace = True )
359+
331360 out_weight = comfy .lora .calculate_weight (self .patches [key ], temp_weight , key )
332- out_weight = comfy .float .stochastic_rounding (out_weight , weight .dtype , seed = string_to_seed (key ))
333- if inplace_update :
334- comfy .utils .copy_to_param (self .model , key , out_weight )
361+ if set_func is None :
362+ out_weight = comfy .float .stochastic_rounding (out_weight , weight .dtype , seed = string_to_seed (key ))
363+ if inplace_update :
364+ comfy .utils .copy_to_param (self .model , key , out_weight )
365+ else :
366+ comfy .utils .set_attr_param (self .model , key , out_weight )
335367 else :
336- comfy . utils . set_attr_param ( self . model , key , out_weight )
368+ set_func ( out_weight , inplace_update = inplace_update , seed = string_to_seed ( key ) )
337369
338370 def load (self , device_to = None , lowvram_model_memory = 0 , force_patch_weights = False , full_load = False ):
339371 mem_counter = 0
0 commit comments