@@ -391,7 +391,8 @@ def controlnet_config(sd):
391391 else :
392392 operations = comfy .ops .disable_weight_init
393393
394- return model_config , operations , load_device , unet_dtype , manual_cast_dtype
394+ offload_device = comfy .model_management .unet_offload_device ()
395+ return model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device
395396
396397def controlnet_load_state_dict (control_model , sd ):
397398 missing , unexpected = control_model .load_state_dict (sd , strict = False )
@@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
405406
406407def load_controlnet_mmdit (sd ):
407408 new_sd = comfy .model_detection .convert_diffusers_mmdit (sd , "" )
408- model_config , operations , load_device , unet_dtype , manual_cast_dtype = controlnet_config (new_sd )
409+ model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (new_sd )
409410 num_blocks = comfy .model_detection .count_blocks (new_sd , 'joint_blocks.{}.' )
410411 for k in sd :
411412 new_sd [k ] = sd [k ]
412413
413- control_model = comfy .cldm .mmdit .ControlNet (num_blocks = num_blocks , operations = operations , device = load_device , dtype = unet_dtype , ** model_config .unet_config )
414+ control_model = comfy .cldm .mmdit .ControlNet (num_blocks = num_blocks , operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
414415 control_model = controlnet_load_state_dict (control_model , new_sd )
415416
416417 latent_format = comfy .latent_formats .SD3 ()
@@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
420421
421422
422423def load_controlnet_hunyuandit (controlnet_data ):
423- model_config , operations , load_device , unet_dtype , manual_cast_dtype = controlnet_config (controlnet_data )
424+ model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (controlnet_data )
424425
425- control_model = comfy .ldm .hydit .controlnet .HunYuanControlNet (operations = operations , device = load_device , dtype = unet_dtype )
426+ control_model = comfy .ldm .hydit .controlnet .HunYuanControlNet (operations = operations , device = offload_device , dtype = unet_dtype )
426427 control_model = controlnet_load_state_dict (control_model , controlnet_data )
427428
428429 latent_format = comfy .latent_formats .SDXL ()
@@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
431432 return control
432433
433434def load_controlnet_flux_xlabs (sd ):
434- model_config , operations , load_device , unet_dtype , manual_cast_dtype = controlnet_config (sd )
435- control_model = comfy .ldm .flux .controlnet_xlabs .ControlNetFlux (operations = operations , device = load_device , dtype = unet_dtype , ** model_config .unet_config )
435+ model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (sd )
436+ control_model = comfy .ldm .flux .controlnet_xlabs .ControlNetFlux (operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
436437 control_model = controlnet_load_state_dict (control_model , sd )
437438 extra_conds = ['y' , 'guidance' ]
438439 control = ControlNet (control_model , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
@@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None):
536537 if manual_cast_dtype is not None :
537538 controlnet_config ["operations" ] = comfy .ops .manual_cast
538539 controlnet_config ["dtype" ] = unet_dtype
540+ controlnet_config ["device" ] = comfy .model_management .unet_offload_device ()
539541 controlnet_config .pop ("out_channels" )
540542 controlnet_config ["hint_channels" ] = controlnet_data ["{}input_hint_block.0.weight" .format (prefix )].shape [1 ]
541543 control_model = comfy .cldm .cldm .ControlNet (** controlnet_config )
0 commit comments