@@ -60,7 +60,7 @@ class StrengthType(Enum):
6060 LINEAR_UP = 2
6161
6262class ControlBase :
63- def __init__ (self , device = None ):
63+ def __init__ (self ):
6464 self .cond_hint_original = None
6565 self .cond_hint = None
6666 self .strength = 1.0
@@ -72,10 +72,6 @@ def __init__(self, device=None):
7272 self .compression_ratio = 8
7373 self .upscale_algorithm = 'nearest-exact'
7474 self .extra_args = {}
75-
76- if device is None :
77- device = comfy .model_management .get_torch_device ()
78- self .device = device
7975 self .previous_controlnet = None
8076 self .extra_conds = []
8177 self .strength_type = StrengthType .CONSTANT
@@ -185,8 +181,8 @@ def set_extra_arg(self, argument, value=None):
185181
186182
187183class ControlNet (ControlBase ):
188- def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT , concat_mask = False ):
189- super ().__init__ (device )
184+ def __init__ (self , control_model = None , global_average_pooling = False , compression_ratio = 8 , latent_format = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT , concat_mask = False ):
185+ super ().__init__ ()
190186 self .control_model = control_model
191187 self .load_device = load_device
192188 if control_model is not None :
@@ -242,7 +238,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
242238 to_concat .append (comfy .utils .repeat_to_batch_size (c , self .cond_hint .shape [0 ]))
243239 self .cond_hint = torch .cat ([self .cond_hint ] + to_concat , dim = 1 )
244240
245- self .cond_hint = self .cond_hint .to (device = self .device , dtype = dtype )
241+ self .cond_hint = self .cond_hint .to (device = self .load_device , dtype = dtype )
246242 if x_noisy .shape [0 ] != self .cond_hint .shape [0 ]:
247243 self .cond_hint = broadcast_image_to (self .cond_hint , x_noisy .shape [0 ], batched_number )
248244
@@ -341,8 +337,8 @@ def forward(self, input):
341337
342338
343339class ControlLora (ControlNet ):
344- def __init__ (self , control_weights , global_average_pooling = False , device = None , model_options = {}): #TODO? model_options
345- ControlBase .__init__ (self , device )
340+ def __init__ (self , control_weights , global_average_pooling = False , model_options = {}): #TODO? model_options
341+ ControlBase .__init__ (self )
346342 self .control_weights = control_weights
347343 self .global_average_pooling = global_average_pooling
348344 self .extra_conds += ["y" ]
@@ -662,12 +658,15 @@ def load_controlnet(ckpt_path, model=None, model_options={}):
662658
663659class T2IAdapter (ControlBase ):
664660 def __init__ (self , t2i_model , channels_in , compression_ratio , upscale_algorithm , device = None ):
665- super ().__init__ (device )
661+ super ().__init__ ()
666662 self .t2i_model = t2i_model
667663 self .channels_in = channels_in
668664 self .control_input = None
669665 self .compression_ratio = compression_ratio
670666 self .upscale_algorithm = upscale_algorithm
667+ if device is None :
668+ device = comfy .model_management .get_torch_device ()
669+ self .device = device
671670
672671 def scale_image_to (self , width , height ):
673672 unshuffle_amount = self .t2i_model .unshuffle_amount
0 commit comments