Skip to content

Commit 754597c

Browse files
Clean up some controlnet code.
Remove self.device which was useless.
1 parent 915fdb5 commit 754597c

1 file changed

Lines changed: 10 additions & 11 deletions

File tree

comfy/controlnet.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class StrengthType(Enum):
6060
LINEAR_UP = 2
6161

6262
class ControlBase:
63-
def __init__(self, device=None):
63+
def __init__(self):
6464
self.cond_hint_original = None
6565
self.cond_hint = None
6666
self.strength = 1.0
@@ -72,10 +72,6 @@ def __init__(self, device=None):
7272
self.compression_ratio = 8
7373
self.upscale_algorithm = 'nearest-exact'
7474
self.extra_args = {}
75-
76-
if device is None:
77-
device = comfy.model_management.get_torch_device()
78-
self.device = device
7975
self.previous_controlnet = None
8076
self.extra_conds = []
8177
self.strength_type = StrengthType.CONSTANT
@@ -185,8 +181,8 @@ def set_extra_arg(self, argument, value=None):
185181

186182

187183
class ControlNet(ControlBase):
188-
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
189-
super().__init__(device)
184+
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
185+
super().__init__()
190186
self.control_model = control_model
191187
self.load_device = load_device
192188
if control_model is not None:
@@ -242,7 +238,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
242238
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
243239
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
244240

245-
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
241+
self.cond_hint = self.cond_hint.to(device=self.load_device, dtype=dtype)
246242
if x_noisy.shape[0] != self.cond_hint.shape[0]:
247243
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
248244

@@ -341,8 +337,8 @@ def forward(self, input):
341337

342338

343339
class ControlLora(ControlNet):
344-
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
345-
ControlBase.__init__(self, device)
340+
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
341+
ControlBase.__init__(self)
346342
self.control_weights = control_weights
347343
self.global_average_pooling = global_average_pooling
348344
self.extra_conds += ["y"]
@@ -662,12 +658,15 @@ def load_controlnet(ckpt_path, model=None, model_options={}):
662658

663659
class T2IAdapter(ControlBase):
664660
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
665-
super().__init__(device)
661+
super().__init__()
666662
self.t2i_model = t2i_model
667663
self.channels_in = channels_in
668664
self.control_input = None
669665
self.compression_ratio = compression_ratio
670666
self.upscale_algorithm = upscale_algorithm
667+
if device is None:
668+
device = comfy.model_management.get_torch_device()
669+
self.device = device
671670

672671
def scale_image_to(self, width, height):
673672
unshuffle_amount = self.t2i_model.unshuffle_amount

0 commit comments

Comments
 (0)