88import logging
99
class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder, a thin configuration wrapper over SDClipModel.

    Points the generic CLIP-style model at the bundled ``t5_config_xxl.json``
    architecture file and supplies T5's special tokens (end=1, pad=0).
    The single ``attention_mask`` flag both enables feeding the attention
    mask into the encoder and requests it back from encoding.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
        # The JSON architecture config ships alongside this module.
        config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config_path,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=comfy.text_encoders.t5.T5,
            # One flag drives both: use the mask during encoding and
            # return it to the caller (encode_token_weights extra output).
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
1414
1515class T5XXLTokenizer (sd1_clip .SDTokenizer ):
1616 def __init__ (self , embedding_directory = None , tokenizer_data = {}):
@@ -39,7 +39,7 @@ def state_dict(self):
3939 return {}
4040
4141class SD3ClipModel (torch .nn .Module ):
42- def __init__ (self , clip_l = True , clip_g = True , t5 = True , dtype_t5 = None , device = "cpu" , dtype = None , model_options = {}):
42+ def __init__ (self , clip_l = True , clip_g = True , t5 = True , dtype_t5 = None , t5_attention_mask = False , device = "cpu" , dtype = None , model_options = {}):
4343 super ().__init__ ()
4444 self .dtypes = set ()
4545 if clip_l :
@@ -57,7 +57,8 @@ def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu
5757
5858 if t5 :
5959 dtype_t5 = comfy .model_management .pick_weight_dtype (dtype_t5 , dtype , device )
60- self .t5xxl = T5XXLModel (device = device , dtype = dtype_t5 , model_options = model_options )
60+ self .t5_attention_mask = t5_attention_mask
61+ self .t5xxl = T5XXLModel (device = device , dtype = dtype_t5 , model_options = model_options , attention_mask = self .t5_attention_mask )
6162 self .dtypes .add (dtype_t5 )
6263 else :
6364 self .t5xxl = None
@@ -87,6 +88,7 @@ def encode_token_weights(self, token_weight_pairs):
8788 lg_out = None
8889 pooled = None
8990 out = None
91+ extra = {}
9092
9193 if len (token_weight_pairs_g ) > 0 or len (token_weight_pairs_l ) > 0 :
9294 if self .clip_l is not None :
@@ -111,7 +113,11 @@ def encode_token_weights(self, token_weight_pairs):
111113 pooled = torch .cat ((l_pooled , g_pooled ), dim = - 1 )
112114
113115 if self .t5xxl is not None :
114- t5_out , t5_pooled = self .t5xxl .encode_token_weights (token_weight_pairs_t5 )
116+ t5_output = self .t5xxl .encode_token_weights (token_weight_pairs_t5 )
117+ t5_out , t5_pooled = t5_output [:2 ]
118+ if self .t5_attention_mask :
119+ extra ["attention_mask" ] = t5_output [2 ]["attention_mask" ]
120+
115121 if lg_out is not None :
116122 out = torch .cat ([lg_out , t5_out ], dim = - 2 )
117123 else :
@@ -123,7 +129,7 @@ def encode_token_weights(self, token_weight_pairs):
123129 if pooled is None :
124130 pooled = torch .zeros ((1 , 768 + 1280 ), device = comfy .model_management .intermediate_device ())
125131
126- return out , pooled
132+ return out , pooled , extra
127133
128134 def load_sd (self , sd ):
129135 if "text_model.encoder.layers.30.mlp.fc1.weight" in sd :
@@ -133,8 +139,8 @@ def load_sd(self, sd):
133139 else :
134140 return self .t5xxl .load_sd (sd )
135141
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False):
    """Build an SD3ClipModel subclass with the encoder selection baked in.

    The returned class keeps the generic ``(device, dtype, model_options)``
    constructor signature the loader code expects, while capturing via
    closure which sub-encoders to instantiate (clip_l / clip_g / t5), the
    T5 weight dtype, and whether the T5 attention mask should be produced.
    """
    class SD3ClipModel_(SD3ClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            super().__init__(
                clip_l=clip_l,
                clip_g=clip_g,
                t5=t5,
                dtype_t5=dtype_t5,
                t5_attention_mask=t5_attention_mask,
                device=device,
                dtype=dtype,
                model_options=model_options,
            )

    return SD3ClipModel_
0 commit comments