Support FA2 flash_attn_with_kvcache for XPU continuous batching #46028

Merged: vasqu merged 6 commits into huggingface:main from YangKai0616:FA2kvcache-XPU on May 27, 2026

@YangKai0616
Contributor

What does this PR do?

XPU now supports flash_attn_with_kvcache in kernels-community/flash-attn2. This PR enables the corresponding path in continuous batching (CB). Results from the transformers pipeline test:

| label                  | samples | avg_in | max_new | time (s) | tokens | tok/s  | mem (GB) |
|------------------------|---------|--------|---------|----------|--------|--------|----------|
| fast_decode_off_varlen | 2       | 512    | 1024    | 59.91    | 2048   | 34.18  | 9.09     |
| fast_decode_on_kvcache | 2       | 512    | 1024    | 44.9     | 2048   | 45.61  | 9.09     |
| fast_decode_off_varlen | 8       | 512    | 1024    | 60.76    | 8192   | 134.82 | 9.09     |
| fast_decode_on_kvcache | 8       | 512    | 1024    | 35.11    | 8192   | 233.33 | 9.09     |
| fast_decode_off_varlen | 8       | 2048   | 512     | 31.04    | 4096   | 131.94 | 9.09     |
| fast_decode_on_kvcache | 8       | 2048   | 512     | 22.39    | 4096   | 182.92 | 9.09     |

fast_decode_on_kvcache speedup over fast_decode_off_varlen:

* 2 samples, avg_in 512, max_new 1024: 1.33x
* 8 samples, avg_in 512, max_new 1024: 1.73x
* 8 samples, avg_in 2048, max_new 512: 1.39x
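
The reported speedups can be re-derived from the time column of the benchmark. The snippet below is only a small arithmetic check, not part of the PR: the `ROWS` values are copied from the table, and `speedups` is a hypothetical helper that pairs the two labels for each (samples, avg_in, max_new) configuration and divides the varlen time by the kvcache time.

```python
from collections import defaultdict

# (label, samples, avg_in, max_new, time_s, tokens), copied from the table.
ROWS = [
    ("fast_decode_off_varlen", 2, 512, 1024, 59.91, 2048),
    ("fast_decode_on_kvcache", 2, 512, 1024, 44.90, 2048),
    ("fast_decode_off_varlen", 8, 512, 1024, 60.76, 8192),
    ("fast_decode_on_kvcache", 8, 512, 1024, 35.11, 8192),
    ("fast_decode_off_varlen", 8, 2048, 512, 31.04, 4096),
    ("fast_decode_on_kvcache", 8, 2048, 512, 22.39, 4096),
]


def speedups(rows):
    """Return {(samples, avg_in, max_new): time_off / time_on}, rounded to 2 places."""
    times_by_config = defaultdict(dict)
    for label, samples, avg_in, max_new, time_s, _tokens in rows:
        times_by_config[(samples, avg_in, max_new)][label] = time_s
    return {
        config: round(
            times["fast_decode_off_varlen"] / times["fast_decode_on_kvcache"], 2
        )
        for config, times in times_by_config.items()
    }


print(speedups(ROWS))
# {(2, 512, 1024): 1.33, (8, 512, 1024): 1.73, (8, 2048, 512): 1.39}
```

Because the token count is identical within each pair, the same ratio also follows from the tok/s column, up to rounding of the printed times.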

Model: Qwen/Qwen3-0.6B.
Device: B60.
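
For readers unfamiliar with the kvcache path: flash_attn_with_kvcache accepts a block_table argument so that each sequence's keys and values can live in non-contiguous cache blocks. The sketch below is a plain-Python reference of that idea for a single head and a single decode query, written for illustration only. `paged_decode_attention` and its list-based layout are assumptions made here and do not mirror the kernel's batched tensor interface.

```python
import math


def paged_decode_attention(q, k_cache, v_cache, block_table, cache_seqlen):
    """Single-head, single-query attention over a paged KV cache.

    q:            query vector (list of floats)
    k_cache:      list of physical blocks, each a list of `block_size` key vectors
    v_cache:      same layout as k_cache, holding value vectors
    block_table:  physical block ids for this sequence, in logical order
    cache_seqlen: number of valid cached tokens for this sequence
    """
    block_size = len(k_cache[0])
    scale = 1.0 / math.sqrt(len(q))

    scores, values = [], []
    for pos in range(cache_seqlen):
        # Translate the logical position into (physical block, offset).
        block = block_table[pos // block_size]
        offset = pos % block_size
        key = k_cache[block][offset]
        scores.append(scale * sum(qi * ki for qi, ki in zip(q, key)))
        values.append(v_cache[block][offset])

    # Numerically stable softmax over the gathered scores.
    max_score = max(scores)
    weights = [math.exp(s - max_score) for s in scores]
    total = sum(weights)

    dim_v = len(values[0])
    return [
        sum(w * v[i] for w, v in zip(weights, values)) / total
        for i in range(dim_v)
    ]
```

With a block size of 2 and block_table = [2, 0], logical tokens 0 and 1 are read from physical block 2 and token 2 from physical block 0. The result matches attention over the same tokens stored contiguously.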

Hi @ArthurZucker, please help review, thanks!

@YangKai0616
Contributor Author

Hi @vasqu, could you help review this attention-related PR? Thank you!

@YangKai0616
Contributor Author

Friendly ping @IlyasMoutawwakil.

Contributor

@vasqu vasqu left a comment

Careful approval. Just a nit: maybe combine the branches into one. cc @remi-or, could you take a look? (It should be safe.)

Comment thread src/transformers/generation/continuous_batching/initialization.py
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +107 to +109
# NOTE: For CUDA, block table should be available with FA2 and FA3, but there seems to be an issue with FA2 atm
cuda_available = torch.cuda.is_available()
fa_cuda = is_flash_attention_requested(config, version=3) and cuda_available
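
The lines under review decide when the block-table (kvcache) decode path may be used. As a rough illustration of how an XPU branch could sit beside the CUDA one, here is a minimal sketch that assumes FA3 is accepted only on CUDA (per the NOTE above) and FA2 only on XPU (per this PR's description). `block_table_supported` and its plain integer and string arguments are hypothetical stand-ins for the real config and device checks; the merged code may be structured differently.

```python
from typing import Optional


def block_table_supported(fa_version: Optional[int], device_type: str) -> bool:
    """Hypothetical gate for the block-table decode path (illustration only).

    fa_version:  requested flash-attention major version (2, 3, or None)
    device_type: accelerator type, e.g. "cuda" or "xpu"
    """
    # CUDA: only FA3 is accepted, since FA2 is reported as problematic for now.
    fa_cuda = fa_version == 3 and device_type == "cuda"
    # XPU: FA2 from kernels-community/flash-attn2 provides the kvcache kernel.
    fa_xpu = fa_version == 2 and device_type == "xpu"
    return fa_cuda or fa_xpu
```

Keeping the two conditions as named booleans and combining them in one expression is one way to address the "combine the branches into one" nit.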
Member

did you check if this is still an issue with fa2 ?

Contributor Author

I tested flash_attention_2 on an A100 and it works normally. However, paged|kernels-community/flash-attn2 fails at the kvcache kernel shape check. This PR can land first so that XPU gets access to the CB pipeline.

Contributor

Hmm, that's interesting. What version of FA2 did you use? Usually the kernels version shouldn't diverge from upstream

But yea anyways, this feels out of scope

Contributor Author

> Hmm, that's interesting. What version of FA2 did you use? Usually the kernels version shouldn't diverge from upstream

flash_attn 2.8.3

Contributor

Could you also check the beta versions? I think they indirectly included some changes as well to FA2

Contributor

Can we open a separate issue to track this? I would merge this PR once the issue is open, and we can track this separately

Contributor Author

Summary:

The test script is in the repo https://github.com/YangKai0616/transformers/tree/tmp-fa-test. Test command:

pytest -ra tests/generation/test_continuous_batching.py::ContinuousBatchingWithAcceleratorTest -k test_flash_attn2_with_kvcache_parity

Test Results:

* flash_attn 2.8.3: passed
* flash_attn 2.8.4 (pip install git+https://github.com/Dao-AILab/flash-attention.git@main): inconsistent output text
* kernels-community/flash-attn2: triggered an internal out CHECK_SHAPE error

This was run with kernels 0.14.1. The failure should be unrelated to the kernels package; it is an internal issue with FA.
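
A parity check of this kind boils down to generating with two attention paths and comparing the outputs per request. The helper below is a hedged sketch of that comparison only; `find_parity_mismatches` and the dict-of-strings shape are assumptions made for illustration, and the actual test in the linked branch may be organized differently.

```python
def find_parity_mismatches(reference, candidate):
    """Compare generated text per request id (hypothetical helper).

    reference, candidate: dicts mapping a request id (str) to generated text.
    Returns {request_id: (reference_text, candidate_text)} for every request
    whose outputs differ or that is missing from one side.
    """
    mismatches = {}
    for request_id in sorted(set(reference) | set(candidate)):
        ref_text = reference.get(request_id)
        cand_text = candidate.get(request_id)
        if ref_text != cand_text:
            mismatches[request_id] = (ref_text, cand_text)
    return mismatches
```

An empty result corresponds to the "passed" outcome above, while a non-empty one corresponds to "inconsistent output text".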

Member

Thanks for surfacing this. huggingface/kernels-community#877 -- let's wait for this one to get merged and reconvene.

Contributor Author

Sure, can we merge the current PR for XPU first? Thanks!

Contributor

Opened an issue over here: huggingface/kernels-community#894. I think we will try to sync with the latest FA2 release instead of the main branch. As you also showed, the main branch is not stable

Merging this PR, thanks for the discussion! Very valuable imo

@vasqu vasqu requested a review from remi-or May 26, 2026 13:44
@vasqu
Contributor

vasqu commented May 26, 2026

Only waiting for Remi, but LGTM overall!

@remi-or
Collaborator

remi-or commented May 26, 2026

Thanks for the ping, LGTM!

@vasqu vasqu enabled auto-merge May 27, 2026 15:59
@vasqu vasqu added this pull request to the merge queue May 27, 2026
Merged via the queue into huggingface:main with commit a65bf6c May 27, 2026
30 checks passed
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
…ingface#46028)

* Support FA2 flash_attn_with_kvcache for XPU continuous batching

* Update according to the comments

* simplify the code

* Code quality
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
…ingface#46028)

* Support FA2 flash_attn_with_kvcache for XPU continuous batching

* Update according to the comments

* simplify the code

* Code quality
6 participants