Skip to content

Fix T5 export: encoder attention mask and static decoder shapes#230

Open
duyhv-qualgo wants to merge 1 commit into
huggingface:mainfrom
duyhv-qualgo:fix/t5-encoder-attention-mask-and-static-shapes
Open

Fix T5 export: encoder attention mask and static decoder shapes#230
duyhv-qualgo wants to merge 1 commit into
huggingface:mainfrom
duyhv-qualgo:fix/t5-encoder-attention-mask-and-static-shapes

Conversation

@duyhv-qualgo
Copy link
Copy Markdown

Problem

Three bugs in optimum/exporters/executorch/integrations.py that cause wrong outputs or export failures when exporting T5 with ExecuTorch.

Bug 1 — Encoder ignores padding tokens

Seq2SeqLMEncoderExportableModule.forward calls self.encoder(input_ids) without an attention_mask. When encoder inputs are padded to a fixed length, PAD tokens (id=0) attend to real positions and corrupt the hidden states at every token position — the encoder output is semantically wrong regardless of the input text.

Fix: compute attention_mask = (input_ids != 0).long() inside forward and pass it to the encoder. Also zero out hidden states at PAD positions so the decoder cross-attention ignores them.

Bug 2 — Decoder ignores encoder attention mask → wrong position bias

Seq2SeqLMDecoderExportableModuleWithStaticCache.forward calls self.decoder(...) without encoder_attention_mask. T5 computes a relative position bias in cross-attention scaled by key_length. Without the mask, key_length equals the full padded length (e.g. 512) instead of the real token count, producing a ~20× logit scale error and completely wrong greedy-decoding output.

Fix: add encoder_attention_mask: Tensor | None = None to forward and pass it through to self.decoder(...).

Bug 3 — Dynamic encoder dim conflicts with static KV cache (T5 export failure)

_export_decoder marks encoder_hidden_states dim-1 as a dynamic symbol (encoder_hidden_seq_length). With transformers 5.0, T5's cross-attention slices a causal mask against the static KV-cache size:

causal_mask = mask[:, :, :, : key_states.shape[-2]]  # static int at export time
position_bias = position_bias + causal_mask            # symbolic dim → shape conflict

This raises RuntimeError: tensor a (1024) must match tensor b (s96) during torch.export.

Fix: use dynamic_shapes=None for T5 decoder export (fully static). Callers pad encoder inputs to max_seq_len before encoding — export() is updated to do this automatically.

Changes

  • Seq2SeqLMEncoderExportableModule.forward: compute and apply attention_mask; zero PAD positions in output.
  • Seq2SeqLMDecoderExportableModuleWithStaticCache.forward: add encoder_attention_mask parameter and pass to decoder.
  • Seq2SeqLMExportableModule._export_decoder (T5 path): use dynamic_shapes=None; pass encoder_attention_mask.
  • Seq2SeqLMExportableModule.export: use max_seq_len-padded encoder input for T5 to match static decoder shape; build and pass encoder_attention_mask.

Verification

Tested on a Helsinki-NLP–style T5 seq2seq (en↔vi) checkpoint, max_seq_len=512, XNNPACK recipe:

Before After
Encoder output at PAD positions non-zero (corrupt) zeroed
ExecuTorch vs HF exact match 0/5 5/5
torch.export succeeds no (shape conflict) yes

Three bugs in the T5 ExecuTorch export path in integrations.py:

1. `Seq2SeqLMEncoderExportableModule.forward` called the encoder without
   `attention_mask`, causing PAD tokens (id=0) to corrupt hidden states at
   real token positions. Fix: compute `attention_mask = (input_ids != 0)`
   internally and zero out PAD positions in the encoder output.

2. `Seq2SeqLMDecoderExportableModuleWithStaticCache.forward` did not pass
   `encoder_attention_mask` to the decoder. T5 cross-attention computes a
   relative position bias scaled by key_length; without the mask, key_length
   equals the full padded length (512) instead of the real encoder length,
   producing ~20x logit scale errors and wrong outputs.

3. `Seq2SeqLMExportableModule._export_decoder` marked `encoder_hidden_states`
   dim-1 as dynamic for T5. With transformers 5.0, T5's cross-attention slices
   a causal mask against the static KV-cache size, creating a symbolic-shape
   conflict during torch.export. Fix: use fully static shapes for T5 (callers
   pad encoder input to max_seq_len).

Verified: ExecuTorch fp32 T5 output matches HuggingFace model.generate()
exactly (5/5 test cases, exact string match) after these fixes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants