TurboQuant KV cache (2/4): CUDA kernels#28561
Draft
TimPietrusky wants to merge 2 commits into
Draft
Conversation
Adds a `TurboQuantKVFusion` graph transformer that rewrites every
GroupQueryAttention node at session-create time to use a TurboQuant
4-bit packed KV cache, plus the schema, session-option keys, and CPU
helpers required for that rewrite. No kernels in this PR — they
land in follow-ups for CUDA and WebGPU.
What this PR includes:
* `core/optimizer/turboquant_kv_fusion.{cc,h}` — the L2 transformer.
Enabled by setting `optimization.turboquant_kv_method` to one of
`turboquant_4bit_nc`, `turboquant_k3v4_nc`, `turboquant_3bit_nc`.
Runs on CUDA + WebGPU EPs. Computes Lloyd-Max centroids for the
given (head_dim, key_bits) and a normalised Walsh–Hadamard matrix,
injects both as graph initializers, and mutates each GQA node's
attributes + past/present tensor types to (uint8, slot_bytes).
* `core/graph/contrib_ops/bert_defs.cc` — extends GroupQueryAttention
with the new attributes (`kv_quant_method`, `key_quant_bits`,
`value_quant_bits`, `norm_correction`) and two new optional inputs
at slots 14 / 15 for the shared k_codebook + hadamard initializers.
* `include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h`
— public option keys `optimization.turboquant_kv_method` and
`optimization.turboquant_kv_boundary`.
* `contrib_ops/cpu/bert/attention_common.h` + `attention_parameters.h`
+ `group_query_attention_helper.h` — `KVQuantMethod` enum, parameter
struct extensions, and `CheckInputs` updates so the fp16 codepath
passes through unchanged when TurboQuant isn't requested.
* `include/onnxruntime/core/framework/int3.h` — new packed `UInt3x8`
type for 3-bit cache slots. Used by the (forthcoming) 3-bit
variants.
* `test/contrib_ops/turboquant_kv_test.cc` — host-side bit-layout
tests for `UInt3x8`. Kernel-level correctness is validated by the
follow-up CUDA / WebGPU PRs.
When `optimization.turboquant_kv_method` is unset or set to "none" /
"off" the transformer doesn't fire and the graph is byte-identical
to today's output.
Design doc + reference NumPy implementation + paper-validation tests
are coming in the Python tooling PR. The CUDA kernels (16-bit accum
WMMA + 4-bit packed cache) and the WebGPU kernels (WGSL encode/decode
with an ApplyAttention fallback for browsers without Subgroups) come
in separate PRs that each depend on this one.
Benches (LFM2.5-1.2B, RTX A40, all measured):
ctx fp16 decode TQ decode speedup
4 K 6.2 s reply 6.0 s reply tied
32 K 26 s 24 s 7 %
64 K 63 s 41 s 53 %
128 K (fp16 OOM) 65 s TQ only
Adds the CUDA TurboQuant attention kernels referenced by the graph rewriter in microsoft#28560. All speedups numbers are end-to-end against fp16 baseline on RTX A40, real models pulled from HuggingFace: Model ctx fp16 reply TQ reply speedup LFM2.5-1.2B 4 K 6.2 s 6.0 s tied LFM2.5-1.2B 32 K 26.0 s 24.1 s 7 % LFM2.5-1.2B 64 K 63.0 s 41.1 s 53 % LFM2.5-1.2B 128 K (fp16 OOM) 65 s TQ-only Qwen3-0.6B 4 K 26.6 s 17.6 s 51 % Qwen3-0.6B 16 K 93 s 58 s 60 % Qwen3-0.6B 32 K 187 s 96 s 94 % KV cache 3.56× smaller (LFM2.5, hd=64) / 3.76× (Qwen3-0.6B, hd=128). Per-layer cosine sim vs fp16: 0.99526 across-context, top-1 token match 7/9 on a 9-step decode chain (matches CUDA TurboQuant kernels in vLLM bit-exact in the encode path; the decode kernels have minor differences in tile scheduling). What this PR contains: * `contrib_ops/cuda/bert/group_query_attention_turboquant.cuh` — templated CUDA kernels: - `TQEncodeKernel` for fresh-K/V pack (Walsh-Hadamard + Lloyd-Max). - `TQFlashAttentionKernel` (v4-lite) for decode-step fused FA on packed K with online softmax. - `TQFlashAttentionQTiledKernel` (v5) for continuation-prefill, ~2× prompt speedup over v4-lite at 4K context. - `TQFlashAttentionWmmaKernel` (v6) for tensor-core Q*K^T scoring at hd=64. Falls through to v5 for other head_dims. * `contrib_ops/cuda/bert/group_query_attention_turboquant_impl.{cu,h}` — launch shims, vectorised uint4 K/V smem loads (8-fp16 per HBM transaction), Option-ε dispatch that routes the prompt step to the stock FlashAttention kernel when past_seq=0 (bit-equivalent to fp16 for the first call so model load is byte-stable). * `contrib_ops/cuda/bert/group_query_attention.cc/.h` — TQ-aware dispatch in the existing GroupQueryAttention CUDA op. When `kv_quant_method != KVQuantMethod::None` we read the new attributes from microsoft#28560 and route to the kernels above. When it's None the op is byte-identical to today. * `contrib_ops/cuda/bert/attention_data.h` — three new fields in `GroupQueryAttentionData` for the codebook / Hadamard pointers and quant-method enum. Unused on the fp16 path. Things I'd appreciate a sharp eye on: 1. **WMMA tile schedule** in v6 — `kBlockK` is currently coupled to `kHeadDim`. Decoupling unlocks hd=128 (Qwen3-0.6B) at the same speed as hd=64 (LFM2.5). I have a follow-up patch ready but kept it out of this PR to minimise diff size. 2. **cp.async** — I tried double-buffered K/V loads in v6. Reverted because the per-tile commit_group/wait_group overhead exceeded the overlap savings at our tile size (74 ms → 95 ms decode at 32K). Worth revisiting if a future commit changes the tile sizing. 3. **Option ε prompt-step delegation** in `group_query_attention.cc` — the dispatcher checks `past_sequence_length == 0` and routes to the standard FA kernel. This is correctness-first: prompt step output is bit-equivalent to fp16, which keeps top-1 token agreement with non-TQ runs stable across cold-load tests. Cost is the wmma v6 prefill kernel never fires on step 1, but step 2+ still uses it because past_seq > 0 by then. Depends on microsoft#28560 (foundation: graph rewrite + schema). The WebGPU kernel PR and the Python tooling PR each depend on microsoft#28560 too, and don't intersect with this file set. ### Validation The host-side bit-layout tests in microsoft#28560 are necessary but not sufficient for the CUDA kernels. Kernel correctness is validated via real-model e2e runs because `onnxruntime_provider_test` does not link `libonnxruntime_providers_cuda.so` directly — the CUDA EP is loaded via `dlopen` at test time, and our kernel launcher symbols aren't exposed through that interface today. Per-layer cosine sim vs fp16 was measured by a Python harness that runs both sessions on the same prompt and compares logits; that harness sits under `tools/quantization/turboquant_kv/validate.py` in the Python tooling PR.
Contributor
|
@TimPietrusky please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
This was referenced May 19, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
CUDA kernels for the TurboQuant 4-bit KV cache path defined by #28560. When the graph transformer in #28560 has rewritten a
GroupQueryAttentionnode's KV-cache schema to the packeduint8layout, the CUDA GQA kernel now dispatches into the launchers added here.Real-model end-to-end speedups (RTX A40)
KV cache 3.56× smaller (LFM2.5, head_dim=64) / 3.76× (Qwen3-0.6B, head_dim=128). Per-layer cosine sim vs fp16 = 0.99526 across-context. Top-1 token match 7-of-9 on a 9-step decode chain.
What this PR contains
contrib_ops/cuda/bert/group_query_attention_turboquant.cuh— templated kernels:TQEncodeKernelpacks fresh K/V (Walsh-Hadamard + Lloyd-Max → 4-bit indices).TQFlashAttentionKernel(v4-lite) fused-FA on packed K with online softmax for the decode step.TQFlashAttentionQTiledKernel(v5) for continuation-prefill — ~2× prompt speedup over v4-lite at 4K.TQFlashAttentionWmmaKernel(v6) tensor-core Q*K^T at head_dim=64.contrib_ops/cuda/bert/group_query_attention_turboquant_impl.{cu,h}— launch shims, vectoriseduint4K/V smem loads (8 fp16 per HBM transaction), Option-ε dispatch that routes the prompt step to the stock FlashAttention kernel whenpast_seq == 0so model-load output is byte-stable vs fp16.contrib_ops/cuda/bert/group_query_attention.cc/.h— TQ-aware dispatch. Whenkv_quant_method != KVQuantMethod::Nonewe read the new attributes from TurboQuant KV cache (1/4): graph rewrite + schema (foundation) #28560 and route to the kernels above; otherwise the op is byte-identical to today.contrib_ops/cuda/bert/attention_data.h— three new fields inGroupQueryAttentionDatafor the codebook / Hadamard pointers and the quant-method enum.Things I'd appreciate a sharp eye on
kBlockKis coupled tokHeadDimtoday. Decoupling unlocks hd=128 (Qwen3-0.6B uses this) at v6 speed. Follow-up patch ready, kept out to minimise diff size.cp.async— tried double-buffered loads, reverted because per-tilecommit_group/wait_groupoverhead exceeded overlap savings at our tile size (74 ms → 95 ms decode at 32K). Worth revisiting if tile sizing changes.group_query_attention.cc— by design, prompt step (past_sequence_length == 0) routes to the stock FA kernel for bit-equivalent output. Cost: v6 wmma prefill never fires on the first call. Step 2+ still uses it.Validation
Host-side bit-layout tests are in #28560. CUDA kernel correctness is validated via real-model e2e runs because
onnxruntime_provider_testdoesn't linklibonnxruntime_providers_cuda.sodirectly — the EP is loaded viadlopenat test time and our launcher symbols aren't exposed there today. Per-layer cosine-sim numbers above come from a Python harness running both sessions on the same prompt and comparing logits. That harness ships in the Python tooling PR.Depends on
Does not intersect with