Skip to content

Commit 478628d

Browse files
committed
fix(attention): add a not-supported assertion for inter-document masking with PyTorch SDPA
1 parent 382a952 commit 478628d

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -825,6 +825,9 @@ def execute_attention(
             )  # (B, nh_q, T, hd)
             y = y.transpose(1, 2).contiguous()  # (B, T, nh_q, hd)
         elif attention_impl == AttentionImplementation.PYTORCH_FLASH:
+            assert (
+                attention_masking_information is None
+            ), "Inter-document masking is not supported for PyTorch Flash Attention."
             k, v = cls.repeat_kv_heads(q, k, v)  # for GQA (group query attention)
             y = torch.nn.functional.scaled_dot_product_attention(
                 query=q,
```

0 commit comments

Comments (0)