Skip to content

Commit 9230f65

Browse files
Fix some controlnets OOMing when loading.
1 parent 6ab1e6f commit 9230f65

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

comfy/controlnet.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,8 @@ def controlnet_config(sd):
391391
else:
392392
operations = comfy.ops.disable_weight_init
393393

394-
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
394+
offload_device = comfy.model_management.unet_offload_device()
395+
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
395396

396397
def controlnet_load_state_dict(control_model, sd):
397398
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
405406

406407
def load_controlnet_mmdit(sd):
407408
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
408-
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
409+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
409410
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
410411
for k in sd:
411412
new_sd[k] = sd[k]
412413

413-
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
414+
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
414415
control_model = controlnet_load_state_dict(control_model, new_sd)
415416

416417
latent_format = comfy.latent_formats.SD3()
@@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
420421

421422

422423
def load_controlnet_hunyuandit(controlnet_data):
423-
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
424+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
424425

425-
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
426+
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
426427
control_model = controlnet_load_state_dict(control_model, controlnet_data)
427428

428429
latent_format = comfy.latent_formats.SDXL()
@@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
431432
return control
432433

433434
def load_controlnet_flux_xlabs(sd):
434-
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
435-
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
435+
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
436+
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
436437
control_model = controlnet_load_state_dict(control_model, sd)
437438
extra_conds = ['y', 'guidance']
438439
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None):
536537
if manual_cast_dtype is not None:
537538
controlnet_config["operations"] = comfy.ops.manual_cast
538539
controlnet_config["dtype"] = unet_dtype
540+
controlnet_config["device"] = comfy.model_management.unet_offload_device()
539541
controlnet_config.pop("out_channels")
540542
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
541543
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

0 commit comments

Comments
 (0)