TurboQuant KV cache (3/4): WebGPU kernels + Safari/Firefox fallback#28562
Draft
TimPietrusky wants to merge 2 commits into
Draft
TurboQuant KV cache (3/4): WebGPU kernels + Safari/Firefox fallback#28562TimPietrusky wants to merge 2 commits into
TimPietrusky wants to merge 2 commits into
Conversation
Adds a `TurboQuantKVFusion` graph transformer that rewrites every
GroupQueryAttention node at session-create time to use a TurboQuant
4-bit packed KV cache, plus the schema, session-option keys, and CPU
helpers required for that rewrite. No kernels in this PR — they
land in follow-ups for CUDA and WebGPU.
What this PR includes:
* `core/optimizer/turboquant_kv_fusion.{cc,h}` — the L2 transformer.
Enabled by setting `optimization.turboquant_kv_method` to one of
`turboquant_4bit_nc`, `turboquant_k3v4_nc`, `turboquant_3bit_nc`.
Runs on CUDA + WebGPU EPs. Computes Lloyd-Max centroids for the
given (head_dim, key_bits) and a normalised Walsh–Hadamard matrix,
injects both as graph initializers, and mutates each GQA node's
attributes + past/present tensor types to (uint8, slot_bytes).
* `core/graph/contrib_ops/bert_defs.cc` — extends GroupQueryAttention
with the new attributes (`kv_quant_method`, `key_quant_bits`,
`value_quant_bits`, `norm_correction`) and two new optional inputs
at slots 14 / 15 for the shared k_codebook + hadamard initializers.
* `include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h`
— public option keys `optimization.turboquant_kv_method` and
`optimization.turboquant_kv_boundary`.
* `contrib_ops/cpu/bert/attention_common.h` + `attention_parameters.h`
+ `group_query_attention_helper.h` — `KVQuantMethod` enum, parameter
struct extensions, and `CheckInputs` updates so the fp16 codepath
passes through unchanged when TurboQuant isn't requested.
* `include/onnxruntime/core/framework/int3.h` — new packed `UInt3x8`
type for 3-bit cache slots. Used by the (forthcoming) 3-bit
variants.
* `test/contrib_ops/turboquant_kv_test.cc` — host-side bit-layout
tests for `UInt3x8`. Kernel-level correctness is validated by the
follow-up CUDA / WebGPU PRs.
When `optimization.turboquant_kv_method` is unset or set to "none" /
"off" the transformer doesn't fire and the graph is byte-identical
to today's output.
Design doc + reference NumPy implementation + paper-validation tests
are coming in the Python tooling PR. The CUDA kernels (16-bit accum
WMMA + 4-bit packed cache) and the WebGPU kernels (WGSL encode/decode
with an ApplyAttention fallback for browsers without Subgroups) come
in separate PRs that each depend on this one.
Benches (LFM2.5-1.2B, RTX A40, all measured):
ctx fp16 decode TQ decode speedup
4 K 6.2 s reply 6.0 s reply tied
32 K 26 s 24 s 7 %
64 K 63 s 41 s 53 %
128 K (fp16 OOM) 65 s TQ only
WebGPU EP TurboQuant kernels referenced by the graph rewriter in microsoft#28560. Same packed cache layout and same algorithm as the CUDA PR (microsoft#28561), implemented as two WGSL programs + a C++ orchestrator that mirrors CUDA's Option-ε pattern. Same code runs in three places: 1. Native: `--build_wheel --use_webgpu` (Python wheel, Dawn → Metal / D3D12 / Vulkan). 2. Browser: `--build_wasm --use_webgpu --enable_wasm_jspi` + the JS layer built with `--webgpu-ep --jspi` — onnxruntime-web inside transformers.js. 3. Node.js: `--build_nodejs --use_webgpu` (NAPI binding, same Dawn). What this PR contains: * `contrib_ops/webgpu/bert/turboquant_encode.wgsl.template` — packs fresh fp16 K/V into the uint8 slot layout in a single pass per (batch, kv_head, present_slot). When `s_cache < past_seq_len` the shader byte-copies from past_key / past_value; otherwise it runs the full encode (||k||, FWHT, Lloyd-Max codebook lookup, V min/max, bit-pack). * `contrib_ops/webgpu/bert/turboquant_decode.wgsl.template` — packed cache → fp16 K/V scratch in BNSH layout. Codebook gather, optional norm correction, vec_norm scale, inverse FWHT, V dequant. * `contrib_ops/webgpu/bert/turboquant_attention.{cc,h}` — the `RunTurboQuantAttention` orchestrator. Aliases the uint8 cache as uint32 view tensors so WGSL `array<u32>` bindings work (ORT's ShaderVariableHelper rejects uint8 storage bindings). Branches on `WebGPU::Subgroups` feature: when present, post-decode attention goes through `ApplyFlashAttention` (Q stays BSNH, K/V scratch declared as BNSH via `Q_K_V_BSNH_BNSH_BNSH`); when absent (Safari, Firefox, older Chrome) we transfer Q to BNSH and route through `ApplyAttention` instead. An env-var escape hatch (`ORT_TQ_DISABLE_FA=1`) forces the fallback on subgroups-capable adapters for dev-machine testing. * `contrib_ops/webgpu/bert/group_query_attention.cc/.h` — TQ-aware dispatch in the WebGPU GQA op. Recognises the new attributes from microsoft#28560 (`kv_quant_method`, `key_quant_bits`, `value_quant_bits`, `norm_correction`), reads codebook + Hadamard from input slots 14/15, runs rotary inline before the orchestrator (since `ApplyFlashAttention`'s do_rotary path requires packed QKV which we don't have), and computes `past_sequence_length` from `past_key.shape[2]` rather than `seqlens_k` (the HF causal-LM exports often leave seqlens_k zero-filled on the prompt step; the past tensor's shape is the ground truth either way and matches what `CheckInputs` does for the fp16 path). Verified Apple Silicon Metal numbers (M-series, in-browser ORT-web via WASM/JSPI; LFM2.5-1.2B-Instruct-ONNX): ctx fp16 decode TQ decode decode speedup 4 K 91.3 ms/tok 74.1 ms/tok 1.23x 16 K 179.1 ms/tok 84.0 ms/tok 2.13x Same model in onnxruntime-node on the same machine: 4 K 28.3 ms/tok 20.0 ms/tok 1.41x Quality: cosine-sim 0.993 - 1.000 vs fp16 across a 9-step decode chain. Top-1 token match 7-of-9 (= matches CUDA in microsoft#28561). Fallback path (no subgroups, `ORT_TQ_DISABLE_FA=1`): Qwen3.5 4K, Node.js — TQ decode 14.3 ms/tok = 70 tok/s (vs 28 ms/tok = 36 tok/s on FA path). Faster on this shape because ApplyAttention's split-Vx decode kernel is more cache-friendly for (small Q, large cached K). Prompt step is slower (~25%) because of the BNSH transfer overhead. Cross-OS build (Linux + Windows) green on GHA — both compile clean. Things worth a careful look: 1. **Subgroups fallback path** — `ApplyAttention` is the existing non-flash GQA path. Q/K/V layouts have to line up with what it expects (BNSH for K/V, BSNH for Q). K/V scratch comes out of the decode shader already in BNSH, so only Q needs `TransferBSDToBNSH`. I set `past_sequence_length = 0` and feed the entire present cache as `K` to skip ApplyAttention's internal past-merge. 2. **uint8 → uint32 alias view** in turboquant_attention.cc — the underlying buffer never changes layout, only the type ORT sees. `slot_bytes` is always a multiple of 4 (we pad). 3. **`detail::GetEnvironmentVar`** for the escape hatch — using `std::getenv` breaks MSVC's `-Werror` for C4996. Depends on microsoft#28560. Does not intersect with microsoft#28561 (CUDA) or the Python tooling PR.
Contributor
|
@TimPietrusky please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
WebGPU EP kernels for the TurboQuant 4-bit KV cache path defined by #28560. Same algorithm and packed-cache layout as the CUDA PR (#28561), implemented as two WGSL programs and a C++ orchestrator that mirrors CUDA's Option-ε pattern.
Same code runs in three places:
--build_wheel --use_webgpu) — Python wheel, Dawn → Metal/D3D12/Vulkan.--build_wasm --use_webgpu --enable_wasm_jspi+ JS layer--webgpu-ep --jspi) — onnxruntime-web inside transformers.js.--build_nodejs --use_webgpu) — NAPI binding, same Dawn.Verified numbers (Apple Silicon Metal, in-browser ORT-web via WASM/JSPI, LFM2.5-1.2B)
Same model in onnxruntime-node on the same machine: 28.3 / 20.0 ms/tok (1.41×). Quality: per-step cosine sim 0.993–1.000 vs fp16, top-1 token match 7-of-9 on a 9-step decode chain (matches CUDA in #28561).
What this PR contains
contrib_ops/webgpu/bert/turboquant_encode.wgsl.template— packs fresh fp16 K/V into the uint8 slot layout in one pass per(batch, kv_head, present_slot). Whens_cache < past_seq_lenthe shader byte-copies from past_key/past_value; otherwise it does the full encode (||k||, FWHT, Lloyd-Max codebook lookup, V min/max, bit-pack).contrib_ops/webgpu/bert/turboquant_decode.wgsl.template— packed cache → fp16 K/V scratch in BNSH layout. Codebook gather, optional norm correction, vec_norm scale, inverse FWHT, V dequant.contrib_ops/webgpu/bert/turboquant_attention.{cc,h}— theRunTurboQuantAttentionorchestrator. Aliases the uint8 cache as uint32 view tensors so WGSLarray<u32>bindings work (ORT'sShaderVariableHelperrejects uint8 storage bindings). Branches onWebGPU::Subgroupsfeature:ApplyFlashAttentionwith Q in BSNH, K/V scratch declared BNSH viaQ_K_V_BSNH_BNSH_BNSH.ApplyAttentioninstead.ORT_TQ_DISABLE_FA=1forces the fallback on subgroups-capable adapters for dev-machine testing.contrib_ops/webgpu/bert/group_query_attention.cc/.h— TQ-aware dispatch. Reads the new GQA attributes from TurboQuant KV cache (1/4): graph rewrite + schema (foundation) #28560, reads codebook + Hadamard from input slots 14/15, runs rotary inline before the orchestrator (sinceApplyFlashAttention's do_rotary path requires packed QKV which we don't have), and computespast_sequence_lengthfrompast_key.shape[2]rather thanseqlens_k(HF causal-LM exports often leave seqlens_k zero-filled on the prompt step; past_key's shape is the ground truth either way, mirroring whatCheckInputsdoes on the fp16 path).Fallback numbers (no-subgroups path, Qwen3.5-0.8B-Text @ 4K, Node.js)
Decode is actually faster on the fallback for this shape —
ApplyAttention's split-Vx decode kernel is more cache-friendly for(small Q, large cached K)than FA on our scratch tensors. Prompt is ~25% slower because of the extraTransferBSDToBNSH. Both produce coherent output (no shape errors, no NaN).Things worth a careful look
ApplyAttentionexpects (BNSH for K/V, BSNH for Q). K/V scratch comes out of the decode shader already BNSH; only Q needsTransferBSDToBNSH. I setpast_sequence_length = 0and feed the entire present cache asKto skip ApplyAttention's internal past-merge.slot_bytesis always a multiple of 4 (we pad).detail::GetEnvironmentVarfor the escape hatch —std::getenvbreaks MSVC's-Werror(C4996).Cross-OS
Cross-OS build CI green on Linux + Windows (build/link-only since hosted runners have no GPU). Apple Silicon Metal verified end-to-end on local Mac + a clean GHA
macos-15runner.Depends on
Does not intersect with