[WebGPU] Avoid indirect dispatch in FlashAttention decode to fix perf issues with Vulkan backend + GraphCapture/GraphReplay #28581
Open
hariharans29 wants to merge 2 commits into
…an+GraphCapture regression

Dawn's TransformIndirectDispatchBuffer path on Vulkan costs ~80us per call (vkAllocateMemory + vkAllocateDescriptorSets + bind group build). With 28 layers x 2 indirect kernels (FlashAttentionDecodeQKT, FlashAttentionDecodeSplitVx), this adds ~4.5ms per decoded token, halving graph-capture throughput on NV/Vulkan.

Replace SetIndirectDispatchTensor with a direct SetDispatchGroupSize using num_present_sequence_length_tile (worst-case constant, already in scope as a uniform). The shaders already mask over-dispatched workgroups via a workgroup-uniform early-return on seqlens_k, so no shader change is needed and there is no divergence cost.

Validated on Qwen3-1.7B (prompt=3000, gen=500) on NV/RTX/Windows:

  Vulkan baseline:  gc0=90.7 tps  gc1=59.7 tps   (~34% regression)
  Vulkan with fix:  gc0=97.1 tps  gc1=122.4 tps  (+105% on gc1)
  D3D12 baseline:   gc0=72.0 tps  gc1=114.6 tps
  D3D12 with fix:   gc0=67.6 tps  gc1=127.3 tps  (+11% on gc1)

D3D12 is largely unaffected because the per-dispatch indirect overhead there is an order of magnitude smaller than on Vulkan.
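The per-token overhead and the percentages quoted in the commit message can be re-derived from the figures it gives. The following is only a small arithmetic check; the constant and function names are illustrative and do not come from the PR.

```python
# Figures taken from the commit message; names here are illustrative only.
LAYERS = 28
INDIRECT_KERNELS_PER_LAYER = 2   # FlashAttentionDecodeQKT, FlashAttentionDecodeSplitVx
COST_PER_CALL_US = 80.0          # approximate cost of one indirect dispatch on Vulkan

# Extra time per decoded token spent in the indirect-dispatch transform.
overhead_ms = LAYERS * INDIRECT_KERNELS_PER_LAYER * COST_PER_CALL_US / 1000.0


def pct_change(before: float, after: float) -> float:
    """Relative change from `before` to `after`, in percent."""
    return (after - before) / before * 100.0


vulkan_regression = pct_change(90.7, 59.7)    # baseline: gc0 -> gc1
vulkan_gain = pct_change(59.7, 122.4)         # gc1: baseline -> with fix
d3d12_gain = pct_change(114.6, 127.3)         # gc1: baseline -> with fix

print(f"overhead per token: {overhead_ms:.2f} ms")
print(f"Vulkan gc1 vs gc0 (baseline): {vulkan_regression:.0f}%")
print(f"Vulkan gc1 gain with fix: {vulkan_gain:.0f}%")
print(f"D3D12 gc1 gain with fix: {d3d12_gain:.0f}%")
```

The 56 indirect dispatches per token at ~80us each give 4.48 ms, which the message rounds to ~4.5ms.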
Pull request overview
This PR improves WebGPU FlashAttention decode performance under graph capture/replay, primarily on Vulkan/Dawn, by avoiding Dawn’s costly indirect-dispatch transformation path and instead using a direct dispatch sized to a safe worst-case tile count (with existing shader-side masking).
Changes:
- Replace indirect dispatch (`SetIndirectDispatchTensor`) with direct `SetDispatchGroupSize` in `FlashAttentionDecodeQKT` and `FlashAttentionDecodeSplitVx`.
- Dispatch worst-case tiles when the (formerly) indirect-dispatch path is selected, relying on existing `seqlens_k` masking in shaders.
- Mark `indirect_buffer` as unused in the updated decode kernels.
hariharans29 added a commit that referenced this pull request on May 20, 2026
- Drop the now-unused indirect dispatch buffer entirely: remove the GPU tensor allocation in ApplyFlashAttention, the indirect_buffer / prepare_indirect_dispatch plumbing from CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram (host + WGSL), and the indirect_buffer parameter from ComputeFlashAttentionDecodeQKT / ComputeFlashAttentionDecodeSplitVxScore. Both decode programs now use direct SetDispatchGroupSize, so the indirect buffer was being allocated, bound, and written by thread 0 of the prologue but never read.
- Add a clarifying comment at both decode dispatch sites noting that, despite the legacy flag name, use_indirect_dispatch now selects worst-case tiling for a direct dispatch rather than issuing an indirect dispatch.

Pure cleanup. Validated on Qwen3-1.7B (Vulkan + D3D12 1k prompts, graph capture on/off); no perf change vs. baseline.
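To illustrate what the legacy flag now controls, here is a minimal sketch of how a host-side tile count might be chosen. It is an assumption about the shape of the logic, not the PR's actual code: the function names and the idea that the non-flagged path sizes the dispatch from the current total sequence length are hypothetical.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for positive values."""
    return (a + b - 1) // b


def decode_tile_count(use_indirect_dispatch: bool,
                      total_sequence_length: int,
                      present_sequence_length: int,
                      tile_size: int) -> int:
    """Hypothetical helper: number of sequence tiles for a direct dispatch.

    Assumption: when the legacy flag is set (graph capture), the current
    sequence length is not available on the host, so the dispatch is sized
    from the KV cache capacity (the worst case). Otherwise it is sized from
    the actual total sequence length.
    """
    if use_indirect_dispatch:
        return ceil_div(present_sequence_length, tile_size)
    return ceil_div(total_sequence_length, tile_size)


# Example with made-up sizes: 1001 valid tokens in a 4096-entry cache.
worst_case_tiles = decode_tile_count(True, 1001, 4096, 64)
exact_tiles = decode_tile_count(False, 1001, 4096, 64)
print(worst_case_tiles, exact_tiles)
```

In either branch the result feeds a direct dispatch; only the tile count differs.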
Description
Dawn's `TransformIndirectDispatchBuffer` path on Vulkan adds significant overhead per call (vkAllocateMemory + vkAllocateDescriptorSets + bind group build). For Qwen3-1.7B, with 28 layers x 2 indirect kernels (`FlashAttentionDecodeQKT`, `FlashAttentionDecodeSplitVx`), that overhead is paid on every decoded token under graph capture/replay, halving graph-capture throughput on NV/Vulkan (measured on Windows with an RTX 5060 Ti).

Replace `SetIndirectDispatchTensor` with a direct `SetDispatchGroupSize` using `num_present_sequence_length_tile` (worst-case constant, already in scope as a uniform). The shaders already mask over-dispatched workgroups via a workgroup-uniform early return on `seqlens_k`, so no shader change is needed and there is no divergence cost.
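The masking argument can be sketched with a small CPU simulation: dispatching the worst-case number of tiles produces the same output as dispatching the exact number, because every workgroup whose tile starts past the valid length returns before doing any work. This is a toy model under stated assumptions, not the WGSL shader: `TILE_SIZE`, `valid_len` (standing in for the length the shader derives from `seqlens_k`), and the helper names are all hypothetical.

```python
from math import ceil

TILE_SIZE = 4  # hypothetical number of keys handled per workgroup


def workgroup(tile_idx, q, k_cache, valid_len, out):
    """Simulate one workgroup computing q.k scores for one tile of keys.

    Returns True if the workgroup did any work.
    """
    start = tile_idx * TILE_SIZE
    # Workgroup-uniform early return: tile_idx and valid_len are the same for
    # every invocation in the group, so the whole group exits together.
    if start >= valid_len:
        return False
    for row in range(start, min(start + TILE_SIZE, valid_len)):
        out[row] = sum(a * b for a, b in zip(q, k_cache[row]))
    return True


def dispatch(num_tiles, q, k_cache, valid_len):
    """Run `num_tiles` workgroups; return scores and the count of active groups."""
    out = [0.0] * len(k_cache)
    active = sum(workgroup(t, q, k_cache, valid_len, out) for t in range(num_tiles))
    return out, active


present_len = 16                      # KV cache capacity (worst case)
valid_len = 6                         # tokens actually present
q = [1.0, 2.0]
k_cache = [[float(i), 1.0] for i in range(present_len)]

exact_tiles = ceil(valid_len / TILE_SIZE)     # what an indirect dispatch would launch
worst_tiles = ceil(present_len / TILE_SIZE)   # what the direct dispatch launches

exact_out, exact_active = dispatch(exact_tiles, q, k_cache, valid_len)
worst_out, worst_active = dispatch(worst_tiles, q, k_cache, valid_len)
print(exact_out == worst_out, exact_active, worst_active)
```

The extra workgroups in the worst-case launch cost only the early-return check, which is the trade the PR makes against the per-call indirect-dispatch overhead.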
Motivation
Perf numbers after fix
Results @ prompt=1000, gen=500 (Qwen3-1.7B, NV/RTX/Windows)