Fix T5 export: encoder attention mask and static decoder shapes #230
Open
duyhv-qualgo wants to merge 1 commit into
Three bugs in the T5 ExecuTorch export path in integrations.py:

1. `Seq2SeqLMEncoderExportableModule.forward` called the encoder without `attention_mask`, causing PAD tokens (id=0) to corrupt hidden states at real token positions. Fix: compute `attention_mask = (input_ids != 0)` internally and zero out PAD positions in the encoder output.
2. `Seq2SeqLMDecoderExportableModuleWithStaticCache.forward` did not pass `encoder_attention_mask` to the decoder. T5 cross-attention computes a relative position bias scaled by key_length; without the mask, key_length equals the full padded length (512) instead of the real encoder length, producing ~20x logit scale errors and wrong outputs.
3. `Seq2SeqLMExportableModule._export_decoder` marked `encoder_hidden_states` dim-1 as dynamic for T5. With transformers 5.0, T5's cross-attention slices a causal mask against the static KV-cache size, creating a symbolic-shape conflict during torch.export. Fix: use fully static shapes for T5 (callers pad encoder input to max_seq_len).

Verified: ExecuTorch fp32 T5 output matches HuggingFace model.generate() exactly (5/5 test cases, exact string match) after these fixes.
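The masking fix in item 1 can be illustrated without PyTorch. The sketch below uses plain Python lists in place of tensors; `build_attention_mask` and `zero_pad_positions` are hypothetical helper names, not functions from `integrations.py`, and the token ids and hidden vectors are made-up toy values. It assumes, as the description does, that id 0 is used only for PAD.

```python
PAD_ID = 0  # T5 pad token id, as stated in the commit message


def build_attention_mask(input_ids):
    """Return 1 for real tokens and 0 for PAD tokens (mirrors `input_ids != 0`)."""
    return [int(token_id != PAD_ID) for token_id in input_ids]


def zero_pad_positions(hidden_states, attention_mask):
    """Zero the hidden vector at every PAD position; keep real positions unchanged."""
    return [
        [value * keep for value in vector]
        for vector, keep in zip(hidden_states, attention_mask)
    ]


# Hypothetical input: three real tokens followed by two PAD tokens.
input_ids = [37, 1782, 1, 0, 0]
mask = build_attention_mask(input_ids)

# Toy 2-dimensional "encoder output", one vector per position.
hidden = [[0.5, -1.0], [2.0, 0.25], [1.5, 1.5], [9.0, 9.0], [7.0, -7.0]]
masked_hidden = zero_pad_positions(hidden, mask)
```

In the actual module the same two steps would operate on tensors inside `forward`, with the mask also passed to the encoder call; this sketch only shows the shape of the logic.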
Problem
Three bugs in `optimum/exporters/executorch/integrations.py` cause wrong outputs or export failures when exporting T5 with ExecuTorch.

Bug 1: Encoder ignores padding tokens

`Seq2SeqLMEncoderExportableModule.forward` calls `self.encoder(input_ids)` without an `attention_mask`. When encoder inputs are padded to a fixed length, PAD tokens (id=0) take part in self-attention and corrupt the hidden states at every token position, so the encoder output is semantically wrong regardless of the input text.

Fix: compute `attention_mask = (input_ids != 0).long()` inside `forward` and pass it to the encoder. Also zero out hidden states at PAD positions so the decoder cross-attention ignores them.

Bug 2: Decoder ignores encoder attention mask → wrong position bias

`Seq2SeqLMDecoderExportableModuleWithStaticCache.forward` calls `self.decoder(...)` without `encoder_attention_mask`. T5 computes a relative position bias in cross-attention scaled by `key_length`. Without the mask, `key_length` equals the full padded length (e.g. 512) instead of the real token count, producing a ~20× logit scale error and completely wrong greedy-decoding output.

Fix: add `encoder_attention_mask: Tensor | None = None` to `forward` and pass it through to `self.decoder(...)`.

Bug 3: Dynamic encoder dim conflicts with static KV cache (T5 export failure)

`_export_decoder` marks `encoder_hidden_states` dim-1 as a dynamic symbol (`encoder_hidden_seq_length`). With transformers 5.0, T5's cross-attention slices a causal mask against the static KV-cache size, which raises `RuntimeError: tensor a (1024) must match tensor b (s96)` during `torch.export`.

Fix: use `dynamic_shapes=None` for T5 decoder export (fully static). Callers pad encoder inputs to `max_seq_len` before encoding; `export()` is updated to do this automatically.

Changes

- `Seq2SeqLMEncoderExportableModule.forward`: compute and apply `attention_mask`; zero PAD positions in output.
- `Seq2SeqLMDecoderExportableModuleWithStaticCache.forward`: add `encoder_attention_mask` parameter and pass to decoder.
- `Seq2SeqLMExportableModule._export_decoder` (T5 path): use `dynamic_shapes=None`; pass `encoder_attention_mask`.
- `Seq2SeqLMExportableModule.export`: use `max_seq_len`-padded encoder input for T5 to match static decoder shape; build and pass `encoder_attention_mask`.

Verification

Tested on a Helsinki-NLP-style T5 seq2seq (en↔vi) checkpoint, `max_seq_len=512`, XNNPACK recipe:

- `torch.export` succeeds
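Because the Bug 3 fix makes the decoder export fully static, anything calling the exported program has to supply encoder input of exactly `max_seq_len` tokens. A minimal sketch of that caller-side padding follows, using plain Python lists rather than tensors. `pad_to_max_len` is a hypothetical helper name, not part of the PR, and right-padding with PAD id 0 is an assumption based on the description above.

```python
PAD_ID = 0  # assumed T5 pad token id, per the description above


def pad_to_max_len(input_ids, max_seq_len, pad_id=PAD_ID):
    """Right-pad a token-id list to exactly `max_seq_len` and build its mask.

    Raises ValueError if the input is already longer than `max_seq_len`,
    since a statically shaped export cannot accept it.
    """
    if len(input_ids) > max_seq_len:
        raise ValueError(
            f"input has {len(input_ids)} tokens, exceeds max_seq_len={max_seq_len}"
        )
    num_pad = max_seq_len - len(input_ids)
    padded = list(input_ids) + [pad_id] * num_pad
    mask = [1] * len(input_ids) + [0] * num_pad
    return padded, mask


# Hypothetical usage with a short max_seq_len for readability.
padded_ids, encoder_attention_mask = pad_to_max_len([37, 1782, 1], max_seq_len=8)
```

The returned mask is the value that would be forwarded as `encoder_attention_mask` so that the padded tail is ignored by the decoder.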