Skip to content

Commit 1b80895

Browse files
Make clip loader nodes support loading sd3 t5xxl in lower precision.
Add attention mask support in the SD3 text encoder code.
1 parent 5f9d5a2 commit 1b80895

2 files changed

Lines changed: 31 additions & 20 deletions

File tree

comfy/sd.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,19 @@ def detect_te_model(sd):
431431
return TEModel.T5_BASE
432432
return None
433433

def t5xxl_weight_dtype(clip_data):
    """Detect the storage dtype of a t5xxl checkpoint.

    Probes each state dict in clip_data for a representative deep-layer
    weight of the T5-XXL encoder and returns its dtype, so the text
    encoder can be instantiated in the same (possibly lowered) precision
    the checkpoint was saved in.

    clip_data: iterable of state dicts (mapping name -> tensor).
    Returns the dtype of the first matching weight, or None when no
    state dict contains a t5xxl weight.
    """
    probe_key = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
    for state_dict in clip_data:
        tensor = state_dict.get(probe_key)
        if tensor is not None:
            return tensor.dtype
    return None
434447
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
435448
clip_data = state_dicts
436449

@@ -462,9 +475,7 @@ class EmptyClass:
462475
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
463476
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
464477
elif te_model == TEModel.T5_XXL:
465-
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
466-
dtype_t5 = weight.dtype
467-
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
478+
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=t5xxl_weight_dtype(clip_data))
468479
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
469480
elif te_model == TEModel.T5_XL:
470481
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
@@ -482,25 +493,19 @@ class EmptyClass:
482493
elif len(clip_data) == 2:
483494
if clip_type == CLIPType.SD3:
484495
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
485-
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models)
496+
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, dtype_t5=t5xxl_weight_dtype(clip_data))
486497
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
487498
elif clip_type == CLIPType.HUNYUAN_DIT:
488499
clip_target.clip = comfy.text_encoders.hydit.HyditModel
489500
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
490501
elif clip_type == CLIPType.FLUX:
491-
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
492-
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
493-
dtype_t5 = None
494-
if weight is not None:
495-
dtype_t5 = weight.dtype
496-
497-
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
502+
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
498503
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
499504
else:
500505
clip_target.clip = sdxl_clip.SDXLClipModel
501506
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
502507
elif len(clip_data) == 3:
503-
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
508+
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
504509
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
505510

506511
parameters = 0

comfy/text_encoders/sd3_clip.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import logging
99

1010
class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder wrapped in the SDClipModel interface.

    Loads the t5_config_xxl.json architecture config that sits next to
    this module. When attention_mask is True the underlying model both
    consumes and returns the token attention mask.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
        # Config file lives alongside this source file.
        config_dir = os.path.dirname(os.path.realpath(__file__))
        textmodel_json_config = os.path.join(config_dir, "t5_config_xxl.json")
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=textmodel_json_config,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},  # T5 sentencepiece: </s>=1, <pad>=0
            model_class=comfy.text_encoders.t5.T5,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
1414

1515
class T5XXLTokenizer(sd1_clip.SDTokenizer):
1616
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -39,7 +39,7 @@ def state_dict(self):
3939
return {}
4040

4141
class SD3ClipModel(torch.nn.Module):
42-
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
42+
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
4343
super().__init__()
4444
self.dtypes = set()
4545
if clip_l:
@@ -57,7 +57,8 @@ def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu
5757

5858
if t5:
5959
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
60-
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
60+
self.t5_attention_mask = t5_attention_mask
61+
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
6162
self.dtypes.add(dtype_t5)
6263
else:
6364
self.t5xxl = None
@@ -87,6 +88,7 @@ def encode_token_weights(self, token_weight_pairs):
8788
lg_out = None
8889
pooled = None
8990
out = None
91+
extra = {}
9092

9193
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
9294
if self.clip_l is not None:
@@ -111,7 +113,11 @@ def encode_token_weights(self, token_weight_pairs):
111113
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
112114

113115
if self.t5xxl is not None:
114-
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
116+
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
117+
t5_out, t5_pooled = t5_output[:2]
118+
if self.t5_attention_mask:
119+
extra["attention_mask"] = t5_output[2]["attention_mask"]
120+
115121
if lg_out is not None:
116122
out = torch.cat([lg_out, t5_out], dim=-2)
117123
else:
@@ -123,7 +129,7 @@ def encode_token_weights(self, token_weight_pairs):
123129
if pooled is None:
124130
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
125131

126-
return out, pooled
132+
return out, pooled, extra
127133

128134
def load_sd(self, sd):
129135
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -133,8 +139,8 @@ def load_sd(self, sd):
133139
else:
134140
return self.t5xxl.load_sd(sd)
135141

136-
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False):
    """Build an SD3ClipModel subclass with the given encoder selection baked in.

    Returns a class (not an instance) whose constructor takes only
    device/dtype/model_options, with clip_l/clip_g/t5 enablement, the t5
    weight dtype, and the t5 attention-mask flag frozen from this call.
    """
    class SD3ClipModel_(SD3ClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            super().__init__(
                clip_l=clip_l,
                clip_g=clip_g,
                t5=t5,
                dtype_t5=dtype_t5,
                t5_attention_mask=t5_attention_mask,
                device=device,
                dtype=dtype,
                model_options=model_options,
            )

    return SD3ClipModel_

0 commit comments

Comments
 (0)