Support FA2 flash_attn_with_kvcache for XPU continuous batching#46028
Conversation
|
Hi @vasqu , could you help review this PR related to attention, thank you! |
|
Friendly ping @IlyasMoutawwakil . |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| # NOTE: For CUDA, block table should be available with FA2 and FA3, but there seems to be an issue with FA2 atm | ||
| cuda_available = torch.cuda.is_available() | ||
| fa_cuda = is_flash_attention_requested(config, version=3) and cuda_available |
There was a problem hiding this comment.
did you check if this is still an issue with fa2 ?
There was a problem hiding this comment.
Tested flash_attention_2 on A100 and it works normally. However, paged|kernels-community/flash-attn2 fails at the kvcache kernel shape check. The current PR can first ensure XPU access to the CB pipeline.
There was a problem hiding this comment.
Hmm, that's interesting. What version of FA2 did you use? Usually the kernels version shouldnt diverge from upstream
But yea anyways, this feels out of scope
There was a problem hiding this comment.
Hmm, that's interesting. What version of FA2 did you use? Usually the kernels version shouldnt diverge from upstream
flash_attn 2.8.3
There was a problem hiding this comment.
Could you also check the beta versions? I think they indirectly included some changes as well to FA2
There was a problem hiding this comment.
Can we open a separate issue to track this? Would merge this PR after then and we can track this separately
There was a problem hiding this comment.
Summary:
Test script is in repo https://github.com/YangKai0616/transformers/tree/tmp-fa-test, test command: pytest -ra tests/generation/test_continuous_batching.py::ContinuousBatchingWithAcceleratorTest -k test_flash_attn2_with_kvcache_parity.
Test Results:
flash_attn 2.8.3: passed
flash_attn 2.8.4 (pip install git+https://github.com/Dao-AILab/flash-attention.git@main): inconsistent output text
kernels-community/flash-attn2: triggered internal out CHECK_SHAPE error
kernels 0.14.1, and it should be unrelated to kernels, it's an internal issue with FA.
There was a problem hiding this comment.
Thanks for surfacing this. huggingface/kernels-community#877 -- let's wait for this one to get merged and reconvene.
There was a problem hiding this comment.
Sure, can we merge the current PR for XPU first? Thanks!
There was a problem hiding this comment.
Opened an issue over here huggingface/kernels-community#894 - I think we will try to sync with the latest FA2 release instead of the main branch. As you also showed it's not stable
Merging this PR, thanks for the discussion! Very valuable imo
|
Only waiting for Remi, but LGTM overall! |
|
Thanks for the ping, LGTM! |
…ingface#46028) * Support FA2 flash_attn_with_kvcache for XPU continuous batching * Update according to the comments * simplify the code * Code quality
…ingface#46028) * Support FA2 flash_attn_with_kvcache for XPU continuous batching * Update according to the comments * simplify the code * Code quality
What does this PR do?
XPU now supports
flash_attn_with_kvcacheinkernels-community/flash-attn2. This PR enables the relevant path in CB.transformerspipeline test results:Model:
Qwen/Qwen3-0.6B.Device:
B60.Hi @ArthurZucker , pls help review, thanks!