TurboQuant KV cache (4/4): Python reference impl + last_token_logits patcher #28563
Draft
TimPietrusky wants to merge 2 commits into
Adds a `TurboQuantKVFusion` graph transformer that rewrites every
GroupQueryAttention node at session-create time to use a TurboQuant
4-bit packed KV cache, plus the schema, session-option keys, and CPU
helpers required for that rewrite. No kernels in this PR — they
land in follow-ups for CUDA and WebGPU.
What this PR includes:
* `core/optimizer/turboquant_kv_fusion.{cc,h}` — the L2 transformer.
Enabled by setting `optimization.turboquant_kv_method` to one of
`turboquant_4bit_nc`, `turboquant_k3v4_nc`, `turboquant_3bit_nc`.
Runs on CUDA + WebGPU EPs. Computes Lloyd-Max centroids for the
given (head_dim, key_bits) and a normalised Walsh–Hadamard matrix,
injects both as graph initializers, and mutates each GQA node's
attributes + past/present tensor types to (uint8, slot_bytes).
* `core/graph/contrib_ops/bert_defs.cc` — extends GroupQueryAttention
with the new attributes (`kv_quant_method`, `key_quant_bits`,
`value_quant_bits`, `norm_correction`) and two new optional inputs
at slots 14 / 15 for the shared k_codebook + hadamard initializers.
* `include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h`
— public option keys `optimization.turboquant_kv_method` and
`optimization.turboquant_kv_boundary`.
* `contrib_ops/cpu/bert/attention_common.h` + `attention_parameters.h`
+ `group_query_attention_helper.h` — `KVQuantMethod` enum, parameter
struct extensions, and `CheckInputs` updates so the fp16 codepath
passes through unchanged when TurboQuant isn't requested.
* `include/onnxruntime/core/framework/int3.h` — new packed `UInt3x8`
type for 3-bit cache slots. Used by the (forthcoming) 3-bit
variants.
* `test/contrib_ops/turboquant_kv_test.cc` — host-side bit-layout
tests for `UInt3x8`. Kernel-level correctness is validated by the
follow-up CUDA / WebGPU PRs.
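As an illustration of what a packed 3-bit slot involves, here is a minimal sketch that packs eight 3-bit codes into three bytes and unpacks them again. The function names `pack_uint3x8` and `unpack_uint3x8` are hypothetical, and the bit order (code `i` stored in bits `3i..3i+2` of a little-endian 24-bit word) is an assumption made for this example; the authoritative layout is whatever `int3.h` and its tests define.

```python
def pack_uint3x8(codes):
    """Pack eight 3-bit codes (0..7) into 3 bytes.

    Assumed layout (illustrative only): code i occupies bits 3*i .. 3*i+2
    of a 24-bit word that is stored little-endian.
    """
    if len(codes) != 8:
        raise ValueError("expected exactly 8 codes")
    word = 0
    for i, code in enumerate(codes):
        if not 0 <= code <= 7:
            raise ValueError("each code must fit in 3 bits")
        word |= code << (3 * i)
    return word.to_bytes(3, "little")


def unpack_uint3x8(packed):
    """Inverse of pack_uint3x8 under the same assumed layout."""
    if len(packed) != 3:
        raise ValueError("expected exactly 3 bytes")
    word = int.from_bytes(packed, "little")
    return [(word >> (3 * i)) & 0b111 for i in range(8)]


codes = [0, 1, 2, 3, 4, 5, 6, 7]
packed = pack_uint3x8(codes)
assert unpack_uint3x8(packed) == codes
```

Eight codes fit in 24 bits exactly, so a group of eight values needs 3 bytes instead of the 16 bytes the same values would take in fp16.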
When `optimization.turboquant_kv_method` is unset or set to "none" /
"off" the transformer doesn't fire and the graph is byte-identical
to today's output.
Design doc + reference NumPy implementation + paper-validation tests
are coming in the Python tooling PR. The CUDA kernels (16-bit accum
WMMA + 4-bit packed cache) and the WebGPU kernels (WGSL encode/decode
with an ApplyAttention fallback for browsers without Subgroups) come
in separate PRs that each depend on this one.
Benches (LFM2.5-1.2B, RTX A40, all measured):
| ctx   | fp16 decode | TQ decode   | speedup |
|-------|-------------|-------------|---------|
| 4 K   | 6.2 s reply | 6.0 s reply | tied    |
| 32 K  | 26 s        | 24 s        | 7 %     |
| 64 K  | 63 s        | 41 s        | 53 %    |
| 128 K | (fp16 OOM)  | 65 s        | TQ only |
…patcher

Reference NumPy implementation, offline graph rewriter, paper-validation tests, and a one-time model patcher that unlocks long-context inference on stock HuggingFace q4f16 ONNX exports.

What this PR contains (all under `onnxruntime/python/tools/quantization/turboquant_kv/`):
* `centroids.py`: Lloyd-Max solver for the K codebook. Computes the optimal scalar quantiser for `N(0, 1/d)` (the distribution of components of `(k / ||k||) @ H_norm` where `k` is fp16 and `H_norm` is the normalised Walsh-Hadamard). Deterministic given `(d, bits)`; identical to what the C++ graph transformer in microsoft#28560 injects.
* `hadamard.py`: Sylvester-construction Walsh-Hadamard, scaled to `H @ H^T = I`. Same scaling as the kernels.
* `packing.py`: uint8 / uint4 / uint3 bit-pack and unpack. Bit layouts match the C++ kernels in microsoft#28561 and microsoft#28562.
* `quantizer.py`: `encode_keys` / `decode_keys` / `encode_values` / `decode_values`. Pure NumPy reference; used by both the offline rewriter and the validation harness.
* `onnx_rewriter.py`: Python equivalent of the C++ graph transformer in microsoft#28560. Useful when users want to ship a pre-rewritten `.onnx` instead of relying on session-create-time rewriting (e.g. so a model registry can stamp a hash).
* `validate.py`: paper-replication tests. 23 / 23 pass against the TurboQuant paper's published numbers. Tests are cross-validated bit-exact against vLLM's reference implementation where overlap exists.
* `benchmark.py`: standalone perf bench. Used to generate the numbers in microsoft#28561 and microsoft#28562 (decode tok/s, KV cache bytes).
* `last_token_logits.py`: standalone model patcher. HuggingFace causal-LM ONNX exports compute logits for *every* prompt position by default. At long contexts (S × vocab > 2^31) this trips an int32 overflow in ORT's CUDA `Cast` kernel; see microsoft#28385. This patcher inserts a `Slice` op before the LM-head MatMul so logits are computed only for the last position (the standard `logits_to_keep=1` pattern in HF transformers). One-time, ~30s, idempotent. Independent of the kernel PRs.

The C++ graph transformer in microsoft#28560 makes this Python tooling optional for online use, but the rewriter + validation tests are still useful for:
* Producing pre-rewritten models for environments where session options can't be set (`onnx_rewriter.py`).
* Reproducing the bit-exact bit layout test the kernels rely on (`packing.py` + the C++ tests in microsoft#28560).
* Validating new kernel changes against the published TurboQuant paper numbers without spinning up a full e2e benchmark (`validate.py`).
* Patching the long-context cliff on stock HF exports today, before microsoft#28385 is fixed upstream (`last_token_logits.py`).

Depends on microsoft#28560 only (for the schema the rewriter writes against). Does not intersect with microsoft#28561 (CUDA) or microsoft#28562 (WebGPU): those read the schema; this one writes it.
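The Sylvester construction and the `H @ H^T = I` scaling mentioned for `hadamard.py` can be sketched in a few lines of plain Python. This is an illustrative sketch, not the PR's code: the function name `hadamard_normalised` is hypothetical, and the real module is described as NumPy-based.

```python
import math


def hadamard_normalised(d):
    """Return a d x d Walsh-Hadamard matrix (Sylvester construction),
    scaled by 1/sqrt(d) so that H @ H^T equals the identity.

    d must be a power of two. The name and list-of-lists representation
    are choices made for this sketch only.
    """
    if d < 1 or d & (d - 1):
        raise ValueError("d must be a power of two")
    h = [[1.0]]
    while len(h) < d:
        # Sylvester step: H_2n = [[H_n, H_n], [H_n, -H_n]]
        top = [row + row for row in h]
        bottom = [row + [-x for x in row] for row in h]
        h = top + bottom
    scale = 1.0 / math.sqrt(d)
    return [[x * scale for x in row] for row in h]


def matmul_transpose(a, b):
    """Compute a @ b^T for two lists of rows."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


H = hadamard_normalised(8)
product = matmul_transpose(H, H)
# Every diagonal entry is 1 and every off-diagonal entry is 0 (up to rounding).
is_identity = all(
    abs(product[i][j] - (1.0 if i == j else 0.0)) < 1e-12
    for i in range(8)
    for j in range(8)
)
assert is_identity
```

Because the scaled matrix is orthogonal, rotating a key by it preserves the key's norm, which is what lets a single norm value be stored alongside the quantised components.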
Description
Reference NumPy implementation, offline graph rewriter, paper-validation tests, and a one-time model patcher that unlocks long-context inference on stock HuggingFace q4f16 ONNX exports. Standalone — depends on the schema in #28560 but not on any kernel PR.
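The long-context limit that the model patcher works around comes down to simple arithmetic on the logits tensor size. The sketch below uses the S × vocab > 2^31 condition quoted in this PR; the helper names, the batch size of 1, and the vocabulary size of 65,536 are assumptions chosen for illustration, not properties of any particular model.

```python
INT32_LIMIT = 2**31  # threshold quoted in the PR: S x vocab > 2^31


def logits_elements(seq_len, vocab_size, keep_last_only=False):
    """Number of logits produced for one sequence (batch size 1 assumed)."""
    positions = 1 if keep_last_only else seq_len
    return positions * vocab_size


def overflows_int32(seq_len, vocab_size, keep_last_only=False):
    """True when the logits tensor exceeds the quoted 2^31 element threshold."""
    return logits_elements(seq_len, vocab_size, keep_last_only) > INT32_LIMIT


def max_full_logit_positions(vocab_size):
    """Largest prompt length whose full logits tensor stays at or below 2^31."""
    return INT32_LIMIT // vocab_size


VOCAB = 65_536  # hypothetical vocabulary size, for illustration only

# Full logits for every prompt position cross the threshold just past 32 K tokens.
assert max_full_logit_positions(VOCAB) == 32_768
assert overflows_int32(32_769, VOCAB)

# Keeping only the last position makes the tensor size independent of prompt length.
assert not overflows_int32(131_072, VOCAB, keep_last_only=True)
```

With only the last position kept, the logits tensor holds one row of `vocab_size` values regardless of how long the prompt is.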
What this PR contains
Under `onnxruntime/python/tools/quantization/turboquant_kv/`:
* `centroids.py`: Lloyd-Max solver for the K codebook. Computes the optimal scalar quantiser for `N(0, 1/d)` (the distribution of components of `(k / ||k||) @ H_norm` where `k` is fp16 and `H_norm` is the normalised Walsh-Hadamard). Deterministic given `(d, bits)`; identical to what the C++ graph transformer in #28560 injects.
* `hadamard.py`: Sylvester-construction Walsh-Hadamard, scaled to `H @ H^T = I`. Same scaling as the kernels.
* `packing.py`: uint8 / uint4 / uint3 bit-pack and unpack. Bit layouts match the C++ kernels in #28561 (CUDA) and #28562 (WebGPU).
* `quantizer.py`: `encode_keys` / `decode_keys` / `encode_values` / `decode_values`. Pure NumPy reference; used by both the offline rewriter and the validation harness.
* `onnx_rewriter.py`: Python equivalent of the C++ graph transformer in #28560. Useful when users want to ship a pre-rewritten `.onnx` instead of relying on session-create-time rewriting (e.g. so a model registry can stamp a hash).
* `validate.py`: paper-replication tests. 23 / 23 pass against the TurboQuant paper's published numbers. Tests are cross-validated bit-exact against vLLM's reference implementation where overlap exists.
* `benchmark.py`: standalone perf bench. Used to generate the numbers in #28561 and #28562.
* `last_token_logits.py`: standalone model patcher. HuggingFace causal-LM ONNX exports compute logits for every prompt position by default. At long contexts (S × vocab > 2³¹) this trips an int32 overflow in ORT's CUDA `Cast` kernel; see #28385. This patcher inserts a `Slice` op before the LM-head MatMul so logits are computed only for the last position (the standard `logits_to_keep=1` pattern in HF transformers). One-time, ~30s, idempotent.
Why ship this even though #28560 makes online rewriting work
Even with the C++ graph transformer landed:
* `onnx_rewriter.py` produces pre-rewritten models for environments where session options can't be set, so a model registry can stamp a hash.
* `packing.py` reproduces the bit-exact bit layout the kernels rely on; `validate.py` uses it to compare against the published TurboQuant paper numbers without spinning up a full e2e benchmark.
* `validate.py` is the cheapest way to validate future kernel changes: both #28561 (CUDA) and #28562 (WebGPU) read the same packed cache layout, and `validate.py` enforces it bit-exact.
* `last_token_logits.py` patches the long-context cliff on stock HF exports today, before #28385 (the CUDA `Cast` kernel int32 overflow on tensors with >2^31 elements) is fixed upstream.
Depends on
Does not intersect with
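For readers unfamiliar with the Lloyd-Max solver that `centroids.py` is described as providing, the following is a minimal standard-library sketch of Lloyd's iteration for a zero-mean Gaussian. It is not the PR's implementation: the function names, the initialisation, and the fixed iteration count are assumptions made for this example, and `1/d` is taken to be the variance of the distribution.

```python
import math


def _pdf(x):
    """Standard normal density (evaluates to 0.0 at +/- infinity)."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def _cdf(x):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def lloyd_max_gaussian(bits, sigma=1.0, iters=500):
    """Centroids of a 2**bits-level scalar quantiser for N(0, sigma^2).

    Alternates the two Lloyd steps: decision boundaries are the midpoints
    between neighbouring centroids, and each centroid is the conditional
    mean of the Gaussian over its cell (closed form for a normal density).
    """
    n = 1 << bits
    # Symmetric starting points spread over [-2, 2] standard deviations.
    centroids = [-2.0 + 4.0 * (i + 0.5) / n for i in range(n)]
    for _ in range(iters):
        mids = [(a + b) / 2.0 for a, b in zip(centroids, centroids[1:])]
        edges = [-math.inf] + mids + [math.inf]
        centroids = [
            (_pdf(lo) - _pdf(hi)) / (_cdf(hi) - _cdf(lo))
            for lo, hi in zip(edges, edges[1:])
        ]
    return [sigma * c for c in centroids]


def codebook_for_head_dim(head_dim, bits):
    """Codebook for components distributed as N(0, 1/head_dim)."""
    return lloyd_max_gaussian(bits, sigma=1.0 / math.sqrt(head_dim))


# With one bit the optimum is +/- sigma * sqrt(2 / pi).
one_bit = lloyd_max_gaussian(1)
assert abs(one_bit[1] - math.sqrt(2.0 / math.pi)) < 1e-9

codebook = codebook_for_head_dim(128, 4)
assert len(codebook) == 16
```

The iteration has no random component, which matches the determinism given `(d, bits)` that the description calls out: the same inputs always yield the same codebook.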