@@ -79,13 +79,19 @@ def __init__(self, device=None):
7979 self .previous_controlnet = None
8080 self .extra_conds = []
8181 self .strength_type = StrengthType .CONSTANT
82+ self .concat_mask = False
83+ self .extra_concat_orig = []
84+ self .extra_concat = None
8285
83- def set_cond_hint (self , cond_hint , strength = 1.0 , timestep_percent_range = (0.0 , 1.0 ), vae = None ):
86+ def set_cond_hint (self , cond_hint , strength = 1.0 , timestep_percent_range = (0.0 , 1.0 ), vae = None , extra_concat = [] ):
8487 self .cond_hint_original = cond_hint
8588 self .strength = strength
8689 self .timestep_percent_range = timestep_percent_range
8790 if self .latent_format is not None :
8891 self .vae = vae
92+ self .extra_concat_orig = extra_concat .copy ()
93+ if self .concat_mask and len (self .extra_concat_orig ) == 0 :
94+ self .extra_concat_orig .append (torch .tensor ([[[[1.0 ]]]]))
8995 return self
9096
9197 def pre_run (self , model , percent_to_timestep_function ):
@@ -100,9 +106,9 @@ def set_previous_controlnet(self, controlnet):
100106 def cleanup (self ):
101107 if self .previous_controlnet is not None :
102108 self .previous_controlnet .cleanup ()
103- if self . cond_hint is not None :
104- del self .cond_hint
105- self .cond_hint = None
109+
110+ self .cond_hint = None
111+ self .extra_concat = None
106112 self .timestep_range = None
107113
108114 def get_models (self ):
@@ -123,6 +129,8 @@ def copy_to(self, c):
123129 c .vae = self .vae
124130 c .extra_conds = self .extra_conds .copy ()
125131 c .strength_type = self .strength_type
132+ c .concat_mask = self .concat_mask
133+ c .extra_concat_orig = self .extra_concat_orig .copy ()
126134
127135 def inference_memory_requirements (self , dtype ):
128136 if self .previous_controlnet is not None :
@@ -175,7 +183,7 @@ def set_extra_arg(self, argument, value=None):
175183
176184
177185class ControlNet (ControlBase ):
178- def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT ):
186+ def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT , concat_mask = False ):
179187 super ().__init__ (device )
180188 self .control_model = control_model
181189 self .load_device = load_device
@@ -189,6 +197,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
189197 self .latent_format = latent_format
190198 self .extra_conds += extra_conds
191199 self .strength_type = strength_type
200+ self .concat_mask = concat_mask
192201
193202 def get_control (self , x_noisy , t , cond , batched_number ):
194203 control_prev = None
@@ -220,6 +229,13 @@ def get_control(self, x_noisy, t, cond, batched_number):
220229 comfy .model_management .load_models_gpu (loaded_models )
221230 if self .latent_format is not None :
222231 self .cond_hint = self .latent_format .process_in (self .cond_hint )
232+ if len (self .extra_concat_orig ) > 0 :
233+ to_concat = []
234+ for c in self .extra_concat_orig :
235+ c = comfy .utils .common_upscale (c , self .cond_hint .shape [3 ], self .cond_hint .shape [2 ], self .upscale_algorithm , "center" )
236+ to_concat .append (comfy .utils .repeat_to_batch_size (c , self .cond_hint .shape [0 ]))
237+ self .cond_hint = torch .cat ([self .cond_hint ] + to_concat , dim = 1 )
238+
223239 self .cond_hint = self .cond_hint .to (device = self .device , dtype = dtype )
224240 if x_noisy .shape [0 ] != self .cond_hint .shape [0 ]:
225241 self .cond_hint = broadcast_image_to (self .cond_hint , x_noisy .shape [0 ], batched_number )
@@ -410,12 +426,17 @@ def load_controlnet_mmdit(sd):
410426 for k in sd :
411427 new_sd [k ] = sd [k ]
412428
413- control_model = comfy .cldm .mmdit .ControlNet (num_blocks = num_blocks , operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
429+ concat_mask = False
430+ control_latent_channels = new_sd .get ("pos_embed_input.proj.weight" ).shape [1 ]
431+ if control_latent_channels == 17 : #inpaint controlnet
432+ concat_mask = True
433+
434+ control_model = comfy .cldm .mmdit .ControlNet (num_blocks = num_blocks , control_latent_channels = control_latent_channels , operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
414435 control_model = controlnet_load_state_dict (control_model , new_sd )
415436
416437 latent_format = comfy .latent_formats .SD3 ()
417438 latent_format .shift_factor = 0 #SD3 controlnet weirdness
418- control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
439+ control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , concat_mask = concat_mask , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
419440 return control
420441
421442
@@ -450,13 +471,16 @@ def load_controlnet_flux_instantx(sd):
450471 num_union_modes = new_sd [union_cnet ].shape [0 ]
451472
452473 control_latent_channels = new_sd .get ("pos_embed_input.weight" ).shape [1 ] // 4
474+ concat_mask = False
475+ if control_latent_channels == 17 :
476+ concat_mask = True
453477
454478 control_model = comfy .ldm .flux .controlnet .ControlNetFlux (latent_input = True , num_union_modes = num_union_modes , control_latent_channels = control_latent_channels , operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
455479 control_model = controlnet_load_state_dict (control_model , new_sd )
456480
457481 latent_format = comfy .latent_formats .Flux ()
458482 extra_conds = ['y' , 'guidance' ]
459- control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
483+ control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , concat_mask = concat_mask , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
460484 return control
461485
462486def convert_mistoline (sd ):
0 commit comments