Skip to content

[WebGPU] Avoid indirect dispatch in FlashAttention decode to fix perf issues with Vulkan backend + GraphCapture/GraphReplay#28581

Open
hariharans29 wants to merge 2 commits into
mainfrom
hari/fa_indirect_dispatch_fix
Open

[WebGPU] Avoid indirect dispatch in FlashAttention decode to fix perf issues with Vulkan backend + GraphCapture/GraphReplay#28581
hariharans29 wants to merge 2 commits into
mainfrom
hari/fa_indirect_dispatch_fix

Conversation

@hariharans29
Copy link
Copy Markdown
Member

@hariharans29 hariharans29 commented May 20, 2026

Description

Dawn's TransformIndirectDispatchBuffer path on Vulkan costs a lot of overhead per call (vkAllocateMemory + vkAllocateDescriptorSets + bind group build). For Qwen3-1.7B, with 28 layers x 2 indirect kernels (FlashAttentionDecodeQKT,
FlashAttentionDecodeSplitVx), this adds quite a bit of overhead per decoded token with graph capture/replay, halving
graph-capture throughput on NV/Vulkan (measured on Windows with 5060Ti).

Replace SetIndirectDispatchTensor with a direct SetDispatchGroupSize using num_present_sequence_length_tile (worst-case constant, already in scope as a uniform). The shaders already mask over-dispatched workgroups via a
workgroup-uniform early-return on seqlens_k, so no shader change is needed and there is no divergence cost.

Motivation

Perf numbers after fix

Results @ prompt=1000, gen=500 (Qwen3-1.7B, NV/RTX/Windows)

Backend Baseline(Main) gc0 Baseline(Main) gc1 Fixed(PR) gc0 Fixed(PR) gc1 gc1 Δ
Vulkan 94.9 56.5 94.4 133.6 +136%
D3D12 67.3 111.4 72.6 138.1 +24%

…an+GraphCapture regression

Dawn's TransformIndirectDispatchBuffer path on Vulkan costs ~80us per call
(vkAllocateMemory + vkAllocateDescriptorSets + bind group build). With 28
layers x 2 indirect kernels (FlashAttentionDecodeQKT,
FlashAttentionDecodeSplitVx), this adds ~4.5ms per decoded token, halving
graph-capture throughput on NV/Vulkan.

Replace SetIndirectDispatchTensor with a direct SetDispatchGroupSize using
num_present_sequence_length_tile (worst-case constant, already in scope as a
uniform). The shaders already mask over-dispatched workgroups via a
workgroup-uniform early-return on seqlens_k, so no shader change is needed
and there is no divergence cost.

Validated on Qwen3-1.7B (prompt=3000, gen=500) on NV/RTX/Windows:
  Vulkan baseline:   gc0=90.7 tps  gc1=59.7 tps  (~34% regression)
  Vulkan with fix:   gc0=97.1 tps  gc1=122.4 tps (+105% on gc1)
  D3D12 baseline:    gc0=72.0 tps  gc1=114.6 tps
  D3D12 with fix:    gc0=67.6 tps  gc1=127.3 tps  (+11% on gc1)

D3D12 is largely unaffected because the per-dispatch indirect overhead
there is an order of magnitude smaller than on Vulkan.
@hariharans29 hariharans29 added the ep:WebGPU ort-web webgpu provider label May 20, 2026
@hariharans29 hariharans29 requested a review from Copilot May 20, 2026 06:49
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves WebGPU FlashAttention decode performance under graph capture/replay, primarily on Vulkan/Dawn, by avoiding Dawn’s costly indirect-dispatch transformation path and instead using a direct dispatch sized to a safe worst-case tile count (with existing shader-side masking).

Changes:

  • Replace indirect dispatch (SetIndirectDispatchTensor) with direct SetDispatchGroupSize in FlashAttentionDecodeQKT and FlashAttentionDecodeSplitVx.
  • Dispatch worst-case tiles when the (formerly) indirect-dispatch path is selected, relying on existing seqlens_k masking in shaders.
  • Mark indirect_buffer as unused in the updated decode kernels.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
hariharans29 added a commit that referenced this pull request May 20, 2026
- Drop the now-unused indirect dispatch buffer entirely: remove the
  GPU tensor allocation in ApplyFlashAttention, the indirect_buffer /
  prepare_indirect_dispatch plumbing from CopyKVCacheProgram and
  SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram (host + WGSL),
  and the indirect_buffer parameter from
  ComputeFlashAttentionDecodeQKT / ComputeFlashAttentionDecodeSplitVxScore.
  Both decode programs now use direct SetDispatchGroupSize, so the
  indirect buffer was being allocated, bound, and written by thread 0
  of the prologue but never read.
- Add a clarifying comment at both decode dispatch sites noting that,
  despite the legacy flag name, use_indirect_dispatch now selects
  worst-case tiling for a direct dispatch rather than issuing an
  indirect dispatch.

Pure cleanup. Validated on Qwen3-1.7B (Vulkan + D3D12 1k prompts,
graph capture on/off); no perf change vs. baseline.
- Drop the now-unused indirect dispatch buffer entirely: remove the
  GPU tensor allocation in ApplyFlashAttention, the indirect_buffer /
  prepare_indirect_dispatch plumbing from CopyKVCacheProgram and
  SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram (host + WGSL),
  and the indirect_buffer parameter from
  ComputeFlashAttentionDecodeQKT / ComputeFlashAttentionDecodeSplitVxScore.
  Both decode programs now use direct SetDispatchGroupSize, so the
  indirect buffer was being allocated, bound, and written by thread 0
  of the prologue but never read.
- Add a clarifying comment at both decode dispatch sites noting that,
  despite the legacy flag name, use_indirect_dispatch now selects
  worst-case tiling for a direct dispatch rather than issuing an
  indirect dispatch.

Pure cleanup. Validated on Qwen3-1.7B (Vulkan + D3D12 1k prompts,
graph capture on/off); no perf change vs. baseline.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants