From 0b981002e8c3ccf24d7ca1e96b1356960b656dc2 Mon Sep 17 00:00:00 2001 From: Tim Pietrusky Date: Tue, 19 May 2026 13:27:26 +0200 Subject: [PATCH 1/2] TurboQuant KV-cache: graph rewrite + schema (foundation PR) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a `TurboQuantKVFusion` graph transformer that rewrites every GroupQueryAttention node at session-create time to use a TurboQuant 4-bit packed KV cache, plus the schema, session-option keys, and CPU helpers required for that rewrite. No kernels in this PR — they land in follow-ups for CUDA and WebGPU. What this PR includes: * `core/optimizer/turboquant_kv_fusion.{cc,h}` — the L2 transformer. Enabled by setting `optimization.turboquant_kv_method` to one of `turboquant_4bit_nc`, `turboquant_k3v4_nc`, `turboquant_3bit_nc`. Runs on CUDA + WebGPU EPs. Computes Lloyd-Max centroids for the given (head_dim, key_bits) and a normalised Walsh–Hadamard matrix, injects both as graph initializers, and mutates each GQA node's attributes + past/present tensor types to (uint8, slot_bytes). * `core/graph/contrib_ops/bert_defs.cc` — extends GroupQueryAttention with the new attributes (`kv_quant_method`, `key_quant_bits`, `value_quant_bits`, `norm_correction`) and two new optional inputs at slots 14 / 15 for the shared k_codebook + hadamard initializers. * `include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h` — public option keys `optimization.turboquant_kv_method` and `optimization.turboquant_kv_boundary`. * `contrib_ops/cpu/bert/attention_common.h` + `attention_parameters.h` + `group_query_attention_helper.h` — `KVQuantMethod` enum, parameter struct extensions, and `CheckInputs` updates so the fp16 codepath passes through unchanged when TurboQuant isn't requested. * `include/onnxruntime/core/framework/int3.h` — new packed `UInt3x8` type for 3-bit cache slots. Used by the (forthcoming) 3-bit variants. * `test/contrib_ops/turboquant_kv_test.cc` — host-side bit-layout tests for `UInt3x8`. Kernel-level correctness is validated by the follow-up CUDA / WebGPU PRs. When `optimization.turboquant_kv_method` is unset or set to "none" / "off" the transformer doesn't fire and the graph is byte-identical to today's output. Design doc + reference NumPy implementation + paper-validation tests are coming in the Python tooling PR. The CUDA kernels (16-bit accum WMMA + 4-bit packed cache) and the WebGPU kernels (WGSL encode/decode with an ApplyAttention fallback for browsers without Subgroups) come in separate PRs that each depend on this one. Benches (LFM2.5-1.2B, RTX A40, all measured): ctx fp16 decode TQ decode speedup 4 K 6.2 s reply 6.0 s reply tied 32 K 26 s 24 s 7 % 64 K 63 s 41 s 53 % 128 K (fp16 OOM) 65 s TQ only --- include/onnxruntime/core/framework/int3.h | 116 +++++ .../onnxruntime_session_options_config_keys.h | 20 + .../contrib_ops/cpu/bert/attention_common.h | 22 + .../cpu/bert/attention_parameters.h | 8 + .../cpu/bert/group_query_attention_helper.h | 24 +- .../core/graph/contrib_ops/bert_defs.cc | 21 + .../core/optimizer/graph_transformer_utils.cc | 24 + .../core/optimizer/turboquant_kv_fusion.cc | 409 ++++++++++++++++++ .../core/optimizer/turboquant_kv_fusion.h | 52 +++ .../test/contrib_ops/turboquant_kv_test.cc | 286 ++++++++++++ 10 files changed, 972 insertions(+), 10 deletions(-) create mode 100644 include/onnxruntime/core/framework/int3.h create mode 100644 onnxruntime/core/optimizer/turboquant_kv_fusion.cc create mode 100644 onnxruntime/core/optimizer/turboquant_kv_fusion.h create mode 100644 onnxruntime/test/contrib_ops/turboquant_kv_test.cc diff --git a/include/onnxruntime/core/framework/int3.h b/include/onnxruntime/core/framework/int3.h new file mode 100644 index 0000000000000..e5724ada5587a --- /dev/null +++ b/include/onnxruntime/core/framework/int3.h @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/common/common.h" +#include + +namespace onnxruntime { + +// Stores 8 packed 3-bit unsigned elements in 3 bytes (24 bits exactly). +// +// Bit layout (within the 24-bit word, little-endian): +// word = byte[0] | (byte[1] << 8) | (byte[2] << 16) +// value_i = (word >> (i * 3)) & 0x7 for i in [0, 8) +// +// This layout matches vLLM's TurboQuant store kernel +// (vllm/v1/attention/ops/triton_turboquant_store.py) so binaries produced by +// either runtime are interchangeable. +// +// Used for: +// - K-cache 3-bit Lloyd-Max codebook indices in TurboQuant. +// - V-cache 3-bit uniform-quant indices when value_quant_bits == 3. +// +// Storage tip: callers should always operate on whole 8-element groups +// (i.e. tensor extents along the packed axis must be multiples of 8). The +// 3-bit encoding has no clean partial-byte representation otherwise. +struct UInt3x8 { + static constexpr uint8_t min_val = 0; + static constexpr uint8_t max_val = 7; + static constexpr size_t kElementsPerPack = 8; + static constexpr size_t kBytesPerPack = 3; + + std::byte bytes_[kBytesPerPack]{}; + + UInt3x8() = default; + + // Construct from 8 unpacked uint8 elements. All must be in [0, 7]. + explicit UInt3x8(const uint8_t (&values)[kElementsPerPack]) { + uint32_t word = 0; + for (size_t i = 0; i < kElementsPerPack; ++i) { + assert(values[i] <= max_val); + word |= (static_cast(values[i]) & 0x7u) << (i * 3); + } + bytes_[0] = static_cast(word & 0xFFu); + bytes_[1] = static_cast((word >> 8) & 0xFFu); + bytes_[2] = static_cast((word >> 16) & 0xFFu); + } + + inline uint8_t GetElem(size_t index) const { + assert(index < kElementsPerPack); + const uint32_t word = static_cast(bytes_[0]) | + (static_cast(bytes_[1]) << 8) | + (static_cast(bytes_[2]) << 16); + return static_cast((word >> (index * 3)) & 0x7u); + } + + inline void SetElem(size_t index, uint8_t val) { + assert(index < kElementsPerPack); + assert(val <= max_val); + uint32_t word = static_cast(bytes_[0]) | + (static_cast(bytes_[1]) << 8) | + (static_cast(bytes_[2]) << 16); + const uint32_t mask = ~(0x7u << (index * 3)); + word = (word & mask) | ((static_cast(val) & 0x7u) << (index * 3)); + bytes_[0] = static_cast(word & 0xFFu); + bytes_[1] = static_cast((word >> 8) & 0xFFu); + bytes_[2] = static_cast((word >> 16) & 0xFFu); + } + + // Number of UInt3x8 packs required to store n unpacked 3-bit elements. + // Caller must ensure n is a multiple of 8. + static constexpr size_t CalcNumPacks(size_t num_3bit_elems) { + return num_3bit_elems / kElementsPerPack; + } + + // Bytes required to store n unpacked 3-bit elements. + static constexpr size_t CalcNumBytes(size_t num_3bit_elems) { + return CalcNumPacks(num_3bit_elems) * kBytesPerPack; + } + + // Bulk unpack: turn N packed groups (3 bytes each) into N*8 bytes [0, 7]. + static bool Unpack(gsl::span dst, gsl::span src) { + if (dst.size() != src.size() * kElementsPerPack) { + return false; + } + for (size_t i = 0; i < src.size(); ++i) { + for (size_t j = 0; j < kElementsPerPack; ++j) { + dst[i * kElementsPerPack + j] = src[i].GetElem(j); + } + } + return true; + } + + // Bulk pack: turn N*8 bytes [0, 7] into N packed groups. + static bool Pack(gsl::span dst, gsl::span src) { + if (src.size() != dst.size() * kElementsPerPack) { + return false; + } + for (size_t i = 0; i < dst.size(); ++i) { + uint8_t buf[kElementsPerPack]; + for (size_t j = 0; j < kElementsPerPack; ++j) { + buf[j] = src[i * kElementsPerPack + j]; + assert(buf[j] <= max_val); + } + dst[i] = UInt3x8(buf); + } + return true; + } +}; + +static_assert(sizeof(UInt3x8) == 3, "UInt3x8 must be exactly 3 bytes"); + +} // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 24557bb81bce3..27a0d664264dc 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -537,3 +537,23 @@ static const char* const kOrtSessionOptionsRecordEpGraphAssignmentInfo = "sessio // - "0": disable. (default) // - "1": enable. static const char* const kOrtSessionOptionEpEnableWeightlessEpContextNodes = "ep.enable_weightless_ep_context_nodes"; + +// TurboQuant KV-cache rewrite preset. When set, the TurboQuantKVFusion graph +// transformer rewrites every GroupQueryAttention node in the model to use +// TurboQuant KV-cache compression at session-create time. This means the user +// can load a stock fp16 q4f16 .onnx model from HuggingFace and opt in to +// TurboQuant via runtime config alone — no offline model conversion needed. +// +// Recognized values: +// - "" / "none" / "off": disabled (default). +// - "turboquant_4bit_nc": 4-bit K, 4-bit V, norm correction on. +// - "turboquant_k3v4_nc": 3-bit K, 4-bit V, norm correction on. +// - "turboquant_3bit_nc": 3-bit K, 3-bit V, norm correction on. +// +// Requires CUDA EP (the TurboQuant kernels live there). No-op on CPU. +static const char* const kOrtSessionOptionsTurboQuantKVMethod = "optimization.turboquant_kv_method"; + +// TurboQuant boundary-protection layers. Number of attention layers at the +// start AND the end of the network to leave at fp16 (matching vLLM's behaviour). +// Default "2". Set to "0" to TurboQuant every GQA layer. +static const char* const kOrtSessionOptionsTurboQuantKVBoundary = "optimization.turboquant_kv_boundary"; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index e886adac03f27..f2169047d14ae 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -66,6 +66,28 @@ enum class KVQuantizationType : int { PER_CHANNEL = 2, }; +// Enum to select KV-cache compression method. +// +// CLASSIC modes use a single global scale per K/V tensor (PER_TENSOR) or one +// scale per channel (PER_CHANNEL) and store linearly-quantized integers. The +// scale tensors are passed via inputs `k_scale` and `v_scale`. +// +// TURBOQUANT applies a Walsh-Hadamard rotation to keys before scalar +// quantization with a static Lloyd-Max codebook (3- or 4-bit), and uses +// uniform asymmetric quantization for values. Codebook + Hadamard are graph +// initializers; per-token vec_norm + per-token v_scale/v_zero live alongside +// the packed bytes in the cache. See: +// onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant.cuh +// onnxruntime/contrib_ops/webgpu/bert/flash_attention_*_turboquant.wgsl.template +// +// TURBOQUANT preserves attention semantics: scoring runs in the rotated space +// because Hadamard is orthogonal, so Q.K = (Q@H).(k_hat@H) * ||k||. +enum class KVQuantMethod : int { + NONE = 0, + CLASSIC = 1, // existing int4/int8/fp8 path via KVQuantizationType + k_scale/v_scale + TURBOQUANT = 2, // Hadamard + Lloyd-Max keys, uniform asymmetric values +}; + constexpr bool LAYOUT_BSNH = false; constexpr bool LAYOUT_BNSH = true; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index 5b7624d11c6fd..e4d1d3247bc73 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -102,6 +102,14 @@ struct GroupQueryAttentionParameters : AttentionParameters { KVQuantizationType k_quant_type = KVQuantizationType::NONE; KVQuantizationType v_quant_type = KVQuantizationType::NONE; int kv_cache_bit_width = 0; + + // TurboQuant KV cache parameters (mutually exclusive with k_quant_type/v_quant_type + // when kv_quant_method == TURBOQUANT). The codebook + Hadamard pointers live in + // the GroupQueryAttentionData struct (CUDA-side) since they're device pointers. + KVQuantMethod kv_quant_method = KVQuantMethod::NONE; + int key_quant_bits = 0; // 3 or 4 when kv_quant_method == TURBOQUANT + int value_quant_bits = 0; // 3 or 4 when kv_quant_method == TURBOQUANT + bool norm_correction = false; }; // Parameters deduced from node attributes and inputs/outputs. diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 0269523e0f34e..9c8adcd041dfb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -143,16 +143,20 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_ // For 4-bit quantized KV cache, actual dimension is head_size / 2 because 2 nibbles are packed into one byte. // Note that we have checked that head_size is a multiple of 8 in Check_QKV. - int packed_head_size = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size; - if (past_key_dims[3] != packed_head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3], " expected ", packed_head_size); - } - if (past_value_dims[3] != packed_head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3], " expected ", packed_head_size); + // Sentinel: kv_cache_bit_width == -1 means "TurboQuant slot layout, skip last-dim check" + // (TQ packs differently and the last dim is bytes-per-slot, not head_size or head_size/2). + if (kv_cache_bit_width != -1) { + int packed_head_size = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size; + if (past_key_dims[3] != packed_head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3], " expected ", packed_head_size); + } + if (past_value_dims[3] != packed_head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3], " expected ", packed_head_size); + } } return Status::OK(); } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 1209446c6a367..160c9a2c12d5f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1250,6 +1250,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Attr("k_quant_type", "Quantization type for K cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.", AttributeProto::STRING, std::string("NONE")) .Attr("v_quant_type", "Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.", AttributeProto::STRING, std::string("NONE")) .Attr("kv_cache_bit_width", "Bit width of quantized KV cache. Supported values are 8 and 4.", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("kv_quant_method", + "KV cache compression method. One of 'none' (default), 'classic' (existing int4/int8/fp8 path " + "via k_scale/v_scale), 'turboquant' (Hadamard rotation + Lloyd-Max keys + uniform values; " + "requires k_codebook + hadamard inputs).", + AttributeProto::STRING, std::string("none")) + .Attr("key_quant_bits", "TurboQuant key quantization bits. 3 or 4. Default 4.", AttributeProto::INT, static_cast(4)) + .Attr("value_quant_bits", "TurboQuant value quantization bits. 3 or 4. Default 4.", AttributeProto::INT, static_cast(4)) + .Attr("norm_correction", "TurboQuant: re-normalize centroid vectors to unit length during decode (1 = on, 0 = off).", AttributeProto::INT, static_cast(0)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" @@ -1314,6 +1322,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(12, "k_scale", "Scale tensor for past_key.", "T_KV_SCALE", OpSchema::Optional) .Input(13, "v_scale", "Scale tensor for past_value.", "T_KV_SCALE", OpSchema::Optional) + .Input(14, + "k_codebook", + "TurboQuant: 1D static Lloyd-Max centroid table with shape (2^key_quant_bits,). " + "Required when kv_quant_method == 'turboquant'.", + "T", + OpSchema::Optional) + .Input(15, + "hadamard", + "TurboQuant: 2D Walsh-Hadamard rotation matrix with shape (head_size, head_size). " + "Required when kv_quant_method == 'turboquant'. Implementations may compute the FWHT " + "directly instead of consuming this matrix; both shapes pass schema validation.", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 6c06abe5fcef5..dae470ad14bff 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -49,6 +49,7 @@ #include "core/optimizer/gemm_sum_fusion.h" #include "core/optimizer/gemm_transpose_fusion.h" #include "core/optimizer/group_query_attention_fusion.h" +#include "core/optimizer/turboquant_kv_fusion.h" #include "core/optimizer/identical_children_consolidation.h" #include "core/optimizer/identity_elimination.h" #include "core/optimizer/label_encoder_fusion.h" @@ -405,6 +406,29 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_eps)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps)); transformers.emplace_back(std::make_unique(cuda_eps)); + + // TurboQuantKVFusion: rewrites GroupQueryAttention nodes to use TurboQuant + // KV-cache compression at session-create time, so users can load a stock + // q4f16 .onnx from HuggingFace and opt-in via session option alone. + // Disabled by default; enabled when kOrtSessionOptionsTurboQuantKVMethod + // is set to a non-empty preset string. Runs only on the CUDA EP. + { + const std::string tq_preset = session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsTurboQuantKVMethod, ""); + if (!tq_preset.empty() && tq_preset != "none" && tq_preset != "off") { + const int tq_boundary = ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsTurboQuantKVBoundary, "2")); + // Allow on CUDA AND WebGPU EPs. Both have TurboQuant kernels. + const InlinedHashSet tq_eps = { + onnxruntime::kCudaExecutionProvider, + onnxruntime::kWebGpuExecutionProvider, + }; + transformers.emplace_back(std::make_unique( + tq_preset, tq_boundary, tq_eps)); + } + } + // Run MatMulAddFusion again after *AttentionFusion transforms with `preserve_attention_pattern = false`, // to cleanup the remaining MatMul-Add that were part of the attention pattern but not detected or fused. transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, false)); diff --git a/onnxruntime/core/optimizer/turboquant_kv_fusion.cc b/onnxruntime/core/optimizer/turboquant_kv_fusion.cc new file mode 100644 index 0000000000000..1031773139f06 --- /dev/null +++ b/onnxruntime/core/optimizer/turboquant_kv_fusion.cc @@ -0,0 +1,409 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/turboquant_kv_fusion.h" + +#include +#include +#include +#include +#include +#include + +#include "core/common/float16.h" +#include "core/graph/graph_utils.h" +#include "core/graph/node_arg.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +namespace onnxruntime { + +namespace { + +// ---------------------------------------------------------------------------- +// Preset parsing. Mirrors the python TQ_PRESETS dictionary used by the +// offline calibration tool, so the same string identifiers work for both. +// ---------------------------------------------------------------------------- + +struct TQPreset { + int key_bits; + int value_bits; + bool norm_correction; +}; + +// Returns true and fills `out` if the preset string is recognised. Empty, +// "none", "off" all return false (= disabled). +bool ParseTQPreset(const std::string& s, TQPreset* out) { + if (s.empty() || s == "none" || s == "off" || s == "0") return false; + if (s == "turboquant_4bit_nc") { *out = {4, 4, true}; return true; } + if (s == "turboquant_k3v4_nc") { *out = {3, 4, true}; return true; } + if (s == "turboquant_3bit_nc") { *out = {3, 3, true}; return true; } + if (s == "turboquant_4bit") { *out = {4, 4, false}; return true; } + if (s == "turboquant_3bit") { *out = {3, 3, false}; return true; } + return false; +} + +// ---------------------------------------------------------------------------- +// Lloyd-Max centroids for N(0, 1/d). Computed at runtime via the same +// fixed-point iteration used by the python reference (centroids.py); the +// resulting values are deterministic given (d, bits) and identical to what +// the offline rewriter injects. +// ---------------------------------------------------------------------------- + +double GaussianPdf(double x, double sigma2) { + static constexpr double kInvSqrtTwoPi = 0.39894228040143267793994605993438; + return (kInvSqrtTwoPi / std::sqrt(sigma2)) * std::exp(-x * x / (2.0 * sigma2)); +} + +double Trapz(double a, double b, int n, double sigma2, + bool weighted_by_x) { + if (n <= 0 || a >= b) return 0.0; + const double h = (b - a) / static_cast(n); + auto eval = [&](double x) { + double f = GaussianPdf(x, sigma2); + return weighted_by_x ? x * f : f; + }; + double acc = 0.5 * (eval(a) + eval(b)); + for (int i = 1; i < n; ++i) { + acc += eval(a + static_cast(i) * h); + } + return acc * h; +} + +std::vector SolveLloydMax(int d, int bits) { + const int n_levels = 1 << bits; + const double sigma2 = 1.0 / static_cast(d); + const double sigma = std::sqrt(sigma2); + + // Initial centroids: evenly spaced in [-3.5σ, 3.5σ]. + std::vector centroids(n_levels); + const double lo = -3.5 * sigma; + const double hi = 3.5 * sigma; + for (int i = 0; i < n_levels; ++i) { + centroids[i] = lo + (hi - lo) * (static_cast(i) + 0.5) / + static_cast(n_levels); + } + + constexpr int kMaxIter = 200; + constexpr double kTol = 1e-10; + constexpr int kIntegN = 200; + for (int iter = 0; iter < kMaxIter; ++iter) { + // Boundaries = midpoints between consecutive centroids. + std::vector bounds(n_levels + 1); + bounds.front() = -10.0 * sigma; + bounds.back() = 10.0 * sigma; + for (int i = 0; i < n_levels - 1; ++i) { + bounds[i + 1] = 0.5 * (centroids[i] + centroids[i + 1]); + } + // New centroids = E[X | b_{i-1} < X <= b_i] under N(0, sigma2). + std::vector new_centroids(n_levels); + double max_drift = 0.0; + for (int i = 0; i < n_levels; ++i) { + const double num = Trapz(bounds[i], bounds[i + 1], kIntegN, sigma2, true); + const double den = Trapz(bounds[i], bounds[i + 1], kIntegN, sigma2, false); + new_centroids[i] = (den > 1e-30) ? (num / den) : centroids[i]; + max_drift = std::max(max_drift, std::abs(new_centroids[i] - centroids[i])); + } + centroids = std::move(new_centroids); + if (max_drift < kTol) break; + } + + std::vector result(n_levels); + for (int i = 0; i < n_levels; ++i) result[i] = static_cast(centroids[i]); + return result; +} + +// ---------------------------------------------------------------------------- +// Walsh-Hadamard matrix (Sylvester construction) of order d, normalised so +// H * H^T = I. d must be a power of two. +// ---------------------------------------------------------------------------- + +std::vector BuildWalshHadamard(int d) { + // Recursively build via Kronecker product with [[1, 1], [1, -1]]. + std::vector H(static_cast(d) * d, 1.0f); + for (int n = 1; n < d; n *= 2) { + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + const float a = H[static_cast(i) * d + j]; + H[static_cast(i) * d + (j + n)] = a; + H[static_cast(i + n) * d + j] = a; + H[static_cast(i + n) * d + (j + n)] = -a; + } + } + } + // Normalise by 1/sqrt(d) so H is orthonormal. + const float inv = 1.0f / std::sqrt(static_cast(d)); + for (auto& v : H) v *= inv; + return H; +} + +// ---------------------------------------------------------------------------- +// Helpers for adding initializers and modifying nodes / IO type info. +// ---------------------------------------------------------------------------- + +NodeArg& AddFp16Initializer(Graph& graph, + const std::string& name, + const std::vector& fp32_data, + const std::vector& shape) { + // Convert fp32 -> fp16. + std::vector fp16_data(fp32_data.size()); + for (size_t i = 0; i < fp32_data.size(); ++i) { + fp16_data[i] = MLFloat16(fp32_data[i]); + } + ONNX_NAMESPACE::TensorProto tp; + tp.set_name(name); + tp.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + for (auto d : shape) tp.add_dims(d); + tp.set_raw_data(fp16_data.data(), fp16_data.size() * sizeof(MLFloat16)); + return graph_utils::AddInitializer(graph, tp); +} + +// Set a string attribute on a node, replacing any existing value of that name. +void SetStringAttr(Node& node, const std::string& name, const std::string& value) { + ONNX_NAMESPACE::AttributeProto attr; + attr.set_name(name); + attr.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); + attr.set_s(value); + node.AddAttributeProto(std::move(attr)); +} + +void SetIntAttr(Node& node, const std::string& name, int64_t value) { + ONNX_NAMESPACE::AttributeProto attr; + attr.set_name(name); + attr.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + attr.set_i(value); + node.AddAttributeProto(std::move(attr)); +} + +// Slot byte sizes for the packed cache layout. +// K slot: ceil(D * key_bits / 8) bytes + 2 bytes of fp16 vec_norm. +// V slot: ceil(D * value_bits / 8) bytes + 4 bytes (v_scale fp16 + v_zero fp16). +int SlotBytes(int head_dim, int bits, bool is_value) { + return ((head_dim * bits) + 7) / 8 + (is_value ? 4 : 2); +} + +// Mutate a NodeArg's TypeProto to (uint8, [..., new_last_dim]). Used to +// rewrite the past_key/past_value/present_key/present_value tensor shapes +// to the packed cache layout. We can't call NodeArg::SetType directly (it's +// private to Graph), so we go through UpdateTypeAndShape with +// override_types=true — that's the public path optimizer transforms use to +// change a graph value's element type. +Status RewriteCacheNodeArg(NodeArg& arg, int64_t new_last_dim, + const logging::Logger& logger) { + const auto* existing = arg.TypeAsProto(); + ONNX_NAMESPACE::TypeProto rewritten; + if (existing != nullptr) { + rewritten = *existing; + } + auto* tt = rewritten.mutable_tensor_type(); + tt->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + if (tt->has_shape() && tt->shape().dim_size() > 0) { + auto* last_dim = tt->mutable_shape()->mutable_dim(tt->shape().dim_size() - 1); + last_dim->clear_dim_param(); + last_dim->set_dim_value(new_last_dim); + } + return arg.UpdateTypeAndShape(rewritten, /*strict=*/false, + /*override_types=*/true, logger); +} + +// Try to read head_dim from the past_key NodeArg's shape. Returns -1 if +// shape information is missing. +int InferHeadDimFromPastKey(const NodeArg* past_key_arg) { + if (past_key_arg == nullptr) return -1; + const auto* tp = past_key_arg->TypeAsProto(); + if (tp == nullptr || !tp->has_tensor_type()) return -1; + const auto& tt = tp->tensor_type(); + if (!tt.has_shape() || tt.shape().dim_size() < 4) return -1; + // past_key shape: (batch, num_kv_heads, seq, head_size). Use the last dim. + const auto& last = tt.shape().dim(tt.shape().dim_size() - 1); + if (!last.has_dim_value()) return -1; + return static_cast(last.dim_value()); +} + +// Cache: one shared codebook + Hadamard initializer per (head_dim, key_bits). +// Avoids duplicating the same 16-fp16 / 64*64-fp16 tensor for every layer. +struct InitCache { + // Stable names so repeated runs of this transformer (e.g. session re-create) + // collide on the same initializer instead of accumulating duplicates. + static std::string CodebookName(int head_dim, int key_bits) { + return "__turboquant_kcodebook__hd" + std::to_string(head_dim) + + "_b" + std::to_string(key_bits); + } + static std::string HadamardName(int head_dim) { + return "__turboquant_hadamard__hd" + std::to_string(head_dim); + } + + NodeArg* GetOrCreateCodebook(Graph& graph, int head_dim, int key_bits) { + const std::string name = CodebookName(head_dim, key_bits); + auto it = codebook_args_.find(name); + if (it != codebook_args_.end()) return it->second; + const ONNX_NAMESPACE::TensorProto* existing_init = nullptr; + if (graph.GetInitializedTensor(name, existing_init)) { + // Already exists in the graph (e.g. user pre-converted), reuse the NodeArg. + NodeArg* existing = graph.GetNodeArg(name); + codebook_args_[name] = existing; + return existing; + } + auto values = SolveLloydMax(head_dim, key_bits); + NodeArg& arg = AddFp16Initializer(graph, name, values, + {static_cast(values.size())}); + codebook_args_[name] = &arg; + return &arg; + } + + NodeArg* GetOrCreateHadamard(Graph& graph, int head_dim) { + const std::string name = HadamardName(head_dim); + auto it = hadamard_args_.find(name); + if (it != hadamard_args_.end()) return it->second; + const ONNX_NAMESPACE::TensorProto* existing_init = nullptr; + if (graph.GetInitializedTensor(name, existing_init)) { + NodeArg* existing = graph.GetNodeArg(name); + hadamard_args_[name] = existing; + return existing; + } + auto values = BuildWalshHadamard(head_dim); + NodeArg& arg = AddFp16Initializer(graph, name, values, + {head_dim, head_dim}); + hadamard_args_[name] = &arg; + return &arg; + } + + private: + std::unordered_map codebook_args_; + std::unordered_map hadamard_args_; +}; + +} // namespace + +Status TurboQuantKVFusion::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, + const logging::Logger& logger) const { + modified = false; + + // Read session option that gates this transformer. No option => skip. + TQPreset preset{}; + if (!ParseTQPreset(preset_, &preset)) { + return Status::OK(); + } + // Boundary protection (number of first/last GQA layers to leave at fp16). + int boundary_n = boundary_n_; + + // First pass: find all GroupQueryAttention nodes (com.microsoft) in order. + std::vector gqa_nodes; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "GroupQueryAttention" && + (node.Domain() == "com.microsoft" || node.Domain().empty())) { + gqa_nodes.push_back(&node); + } + } + if (gqa_nodes.empty()) { + return Status::OK(); + } + + const int n_layers = static_cast(gqa_nodes.size()); + const int skip_lo = std::min(boundary_n, n_layers); + const int skip_hi = std::max(0, n_layers - boundary_n); + + // Try to infer head_dim from the first node that has a usable past_key shape. + int head_dim = -1; + for (Node* node : gqa_nodes) { + if (node->InputDefs().size() > 3) { + head_dim = InferHeadDimFromPastKey(node->InputDefs()[3]); + if (head_dim > 0) break; + } + } + if (head_dim <= 0) { + LOGS(logger, WARNING) << "TurboQuantKVFusion: could not infer head_dim from any " + << "GroupQueryAttention past_key shape; skipping rewrite"; + return Status::OK(); + } + // Sanity: TurboQuant kernels are dispatched on power-of-two head_dim ∈ {64, 128, 256}. + if (head_dim & (head_dim - 1)) { + LOGS(logger, WARNING) << "TurboQuantKVFusion: head_dim=" << head_dim + << " is not a power of two; skipping"; + return Status::OK(); + } + + const int k_slot = SlotBytes(head_dim, preset.key_bits, false); + const int v_slot = SlotBytes(head_dim, preset.value_bits, true); + const int cache_last_dim = std::max(k_slot, v_slot); + + InitCache init_cache; + NodeArg* codebook_arg = init_cache.GetOrCreateCodebook(graph, head_dim, preset.key_bits); + NodeArg* hadamard_arg = init_cache.GetOrCreateHadamard(graph, head_dim); + + // Empty NodeArg, used when we need to pad the input list to slot 14 / 15. + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + + int n_rewritten = 0; + for (int idx = 0; idx < n_layers; ++idx) { + if (idx < skip_lo || idx >= skip_hi) { + LOGS(logger, INFO) << "TurboQuantKVFusion: skipping layer " << idx + << " (boundary protection)"; + continue; + } + Node& node = *gqa_nodes[idx]; + + // Set / replace attributes. + SetStringAttr(node, "kv_quant_method", "turboquant"); + SetIntAttr(node, "key_quant_bits", preset.key_bits); + SetIntAttr(node, "value_quant_bits", preset.value_bits); + SetIntAttr(node, "norm_correction", preset.norm_correction ? 1 : 0); + + // Pad input list to length 16 with empty NodeArgs and wire codebook / hadamard. + auto& input_defs = node.MutableInputDefs(); + while (input_defs.size() < 14) { + input_defs.push_back(&empty_arg); + } + if (input_defs.size() == 14) { + input_defs.push_back(codebook_arg); + } else { + input_defs[14] = codebook_arg; + } + if (input_defs.size() == 15) { + input_defs.push_back(hadamard_arg); + } else { + input_defs[15] = hadamard_arg; + } + + // Keep MutableInputArgsCount in sync with the new input count. ORT + // requires sum(args_count) == size(input_defs); since GroupQueryAttention + // is all-singleton (no variadic inputs), set every slot to 1 — empty + // optional inputs use empty NodeArg sentinels but still consume one slot. + auto& args_count = node.MutableInputArgsCount(); + args_count.assign(input_defs.size(), 1); + + // Rewrite past_key (3), past_value (4), present_key (1), present_value (2) + // to (uint8, [..., cache_last_dim]) by mutating the underlying NodeArg. + auto rewrite_idx = [&](size_t slot, bool is_input) -> Status { + const auto& defs = is_input ? node.InputDefs() : node.OutputDefs(); + if (slot < defs.size() && defs[slot] != nullptr && defs[slot]->Exists()) { + // Look up the canonical NodeArg in the graph and rewrite it. + NodeArg* canonical = graph.GetNodeArg(defs[slot]->Name()); + if (canonical != nullptr) { + ORT_RETURN_IF_ERROR(RewriteCacheNodeArg(*canonical, cache_last_dim, logger)); + } + } + return Status::OK(); + }; + ORT_RETURN_IF_ERROR(rewrite_idx(3, /*is_input=*/true)); + ORT_RETURN_IF_ERROR(rewrite_idx(4, /*is_input=*/true)); + ORT_RETURN_IF_ERROR(rewrite_idx(1, /*is_input=*/false)); + ORT_RETURN_IF_ERROR(rewrite_idx(2, /*is_input=*/false)); + + ++n_rewritten; + } + + if (n_rewritten > 0) { + LOGS(logger, INFO) << "TurboQuantKVFusion: rewrote " << n_rewritten + << " / " << n_layers << " GQA nodes for preset '" + << preset_ << "' (k_slot=" << k_slot + << " v_slot=" << v_slot + << " cache_last_dim=" << cache_last_dim << ")"; + graph.SetGraphResolveNeeded(); + modified = true; + } + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/turboquant_kv_fusion.h b/onnxruntime/core/optimizer/turboquant_kv_fusion.h new file mode 100644 index 0000000000000..07b5d584956d3 --- /dev/null +++ b/onnxruntime/core/optimizer/turboquant_kv_fusion.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class TurboQuantKVFusion + +Pattern-matches existing GroupQueryAttention nodes in the graph and rewrites +them to use TurboQuant KV cache compression. + +What it does: + 1. For each GroupQueryAttention node, set kv_quant_method = "turboquant", + key_quant_bits = 3 or 4, value_quant_bits = 4, norm_correction = true + (configurable via session options). + 2. Inject the static Lloyd-Max codebook (computed by the Python calibration + tool turboquant_kv_quantizer.py) as a constant graph initializer + `__k_centroids`. + 3. Inject the Walsh-Hadamard rotation matrix as a constant initializer + `__hadamard` (head_dim x head_dim, fp16). + 4. Rewrite past_key_values / present_key_values tensor types from fp16 to + uint8 with the new packed shape. + +Mirrors the structure of: + - core/optimizer/group_query_attention_fusion.cc + - core/optimizer/dq_matmulnbits_fusion.cc + +Gated by session option kOrtSessionOptionsTurboQuantKV. +*/ +class TurboQuantKVFusion : public GraphTransformer { + public: + TurboQuantKVFusion( + const std::string& preset = "", + int boundary_n = 2, + const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("TurboQuantKVFusion", compatible_execution_providers), + preset_(preset), + boundary_n_(boundary_n) {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const override; + + std::string preset_; + int boundary_n_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/turboquant_kv_test.cc b/onnxruntime/test/contrib_ops/turboquant_kv_test.cc new file mode 100644 index 0000000000000..9b5e8bd522bf5 --- /dev/null +++ b/onnxruntime/test/contrib_ops/turboquant_kv_test.cc @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "gtest/gtest.h" + +#include "core/framework/int3.h" +#include "test/common/cuda_op_test_utils.h" + +// CUDA kernel correctness for TurboQuant is validated via Phase 5 end-to-end +// model runs (LFM2-1.2B), not via in-process gtests. The reason: the +// `onnxruntime_provider_test` binary does NOT link `libonnxruntime_providers_cuda.so` +// at link time — the CUDA EP is loaded dynamically via dlopen during the test. +// Calling our kernel launcher's symbols directly from a gtest .cc would require +// dlsym + a separate test driver. We keep the host-side `Int3x8` tests (which +// validate the bit-layout that the CUDA kernel must match) and defer CUDA +// correctness to model-level e2e validation. + +namespace onnxruntime { +namespace test { + +// ============================================================================= +// UInt3x8 unit tests. +// ============================================================================= + +TEST(TurboQuantKVTest, Int3x8_RoundtripBijective) { + // 8 values into 3 bytes, then back. Exhaustive over a slice of inputs. + for (uint8_t v0 = 0; v0 <= 7; ++v0) { + for (uint8_t v7 = 0; v7 <= 7; ++v7) { + const uint8_t in[8] = {v0, 1, 2, 3, 4, 5, 6, v7}; + UInt3x8 packed(in); + EXPECT_EQ(packed.GetElem(0), v0); + EXPECT_EQ(packed.GetElem(7), v7); + for (size_t i = 0; i < 8; ++i) { + EXPECT_EQ(packed.GetElem(i), in[i]); + } + } + } +} + +TEST(TurboQuantKVTest, Int3x8_BitLayoutMatchesSpec) { + // Verify the documented bit layout: byte0 = (v0) | (v1<<3) | (low2 of v2 << 6). + // For v = [1, 2, 3, 4, 5, 6, 7, 0]: + // word = 1 | (2<<3) | (3<<6) | (4<<9) | (5<<12) | (6<<15) | (7<<18) | (0<<21) + // = 0x1F58D1 ⇒ byte0=0xD1, byte1=0x58, byte2=0x1F + const uint8_t in[8] = {1, 2, 3, 4, 5, 6, 7, 0}; + UInt3x8 packed(in); + EXPECT_EQ(static_cast(packed.bytes_[0]), 0xD1); + EXPECT_EQ(static_cast(packed.bytes_[1]), 0x58); + EXPECT_EQ(static_cast(packed.bytes_[2]), 0x1F); +} + +TEST(TurboQuantKVTest, Int3x8_BulkPackUnpack) { + std::vector values(128); + for (size_t i = 0; i < values.size(); ++i) { + values[i] = static_cast(i % 8); + } + std::vector packed(UInt3x8::CalcNumPacks(values.size())); + EXPECT_TRUE(UInt3x8::Pack(packed, values)); + + std::vector recovered(values.size()); + EXPECT_TRUE(UInt3x8::Unpack(recovered, packed)); + EXPECT_EQ(values, recovered); +} + +TEST(TurboQuantKVTest, Int3x8_SetGet) { + UInt3x8 p; + for (size_t i = 0; i < 8; ++i) { + p.SetElem(i, static_cast(7 - i)); + } + for (size_t i = 0; i < 8; ++i) { + EXPECT_EQ(p.GetElem(i), static_cast(7 - i)); + } +} + +// ============================================================================= +// CUDA kernel correctness (only built when CUDA EP is enabled). +// Currently a no-op placeholder; see comment at top of file. Kept here as a +// hook so we know where to plumb a dlopen-based driver in a follow-up. +// ============================================================================= + +#if defined(USE_CUDA) +TEST(TurboQuantKVTest, CudaEncodeDecodeRoundtrip_K4V4) { + GTEST_SKIP() << "CUDA kernel direct-test deferred — see comment at top of file"; +} +#endif + +#if 0 // Disabled: needs dlopen-based test driver (see top-of-file comment). +namespace tq_test_helpers { + +// Solve Lloyd-Max codebook on host for N(0, 1/d). +// Mirrors solve_lloyd_max() in python/tools/quantization/turboquant_kv/centroids.py. +// Uses fixed-size buffers (max 16 levels for bits<=4) to avoid GCC 13 false-positive +// -Wstringop-overflow on vector reassignment. +inline std::vector SolveLloydMax(int d, int bits, int max_iter = 200) { + constexpr int kMaxLevels = 16; + const int n_levels = 1 << bits; + if (n_levels > kMaxLevels) { + return std::vector(); // unsupported + } + const double sigma2 = 1.0 / d; + const double sigma = std::sqrt(sigma2); + const double lo = -3.5 * sigma; + const double hi = 3.5 * sigma; + + auto pdf = [sigma2](double x) { + return (1.0 / std::sqrt(2 * M_PI * sigma2)) * std::exp(-x * x / (2 * sigma2)); + }; + auto trapz = [](auto f, double a, double b, int n = 200) { + const double h = (b - a) / n; + double r = 0.5 * (f(a) + f(b)); + for (int i = 1; i < n; ++i) r += f(a + i * h); + return r * h; + }; + + double centroids[kMaxLevels] = {}; + double new_centroids[kMaxLevels] = {}; + double edges[kMaxLevels + 1] = {}; + for (int i = 0; i < n_levels; ++i) { + centroids[i] = lo + (hi - lo) * (i + 0.5) / n_levels; + } + + for (int it = 0; it < max_iter; ++it) { + edges[0] = lo * 3.0; + edges[n_levels] = hi * 3.0; + for (int i = 0; i < n_levels - 1; ++i) { + edges[i + 1] = 0.5 * (centroids[i] + centroids[i + 1]); + } + + double drift = 0.0; + for (int i = 0; i < n_levels; ++i) { + double a = edges[i], b = edges[i + 1]; + double num = trapz([&pdf](double x) { return x * pdf(x); }, a, b); + double den = trapz(pdf, a, b); + new_centroids[i] = (den > 1e-15) ? (num / den) : centroids[i]; + drift = std::max(drift, std::abs(new_centroids[i] - centroids[i])); + } + for (int i = 0; i < n_levels; ++i) centroids[i] = new_centroids[i]; + if (drift < 1e-10) break; + } + + std::vector out(n_levels); + for (int i = 0; i < n_levels; ++i) out[i] = static_cast(centroids[i]); + return out; +} + +inline double CosineSimilarity(const std::vector& a, const std::vector& b) { + double dot = 0.0, na = 0.0, nb = 0.0; + for (size_t i = 0; i < a.size(); ++i) { + dot += a[i] * b[i]; + na += a[i] * a[i]; + nb += b[i] * b[i]; + } + return dot / (std::sqrt(na) * std::sqrt(nb) + 1e-9); +} + +} // namespace tq_test_helpers + +// CUDA correctness test: the encode/decode roundtrip should preserve +// vector direction (cosine similarity vs original > 0.97 for k4v4 in rotated space). +// +// Note: the K reconstruction is in ROTATED space (K · H), so we compare against +// the rotated original, not raw K. This validates the algorithm; a full GQA +// dispatch test that compares attention OUTPUT vs fp16 baseline is left for v2. +TEST(TurboQuantKVTest, CudaEncodeDecodeRoundtrip_K4V4) { + using namespace tq_test_helpers; + + // We don't gate on DefaultCudaExecutionProvider() here because the CUDA EP + // may not be initialized at the moment of test discovery; cudaGetDeviceCount + // is a sufficient runtime check. + int dev_count = 0; + if (cudaGetDeviceCount(&dev_count) != cudaSuccess || dev_count == 0) { + GTEST_SKIP() << "No CUDA device available"; + } + + constexpr int batch_size = 1; + constexpr int n_kv_heads = 4; + constexpr int seq_len = 16; + constexpr int head_size = 128; + constexpr int key_bits = 4; + constexpr int value_bits = 4; + constexpr bool norm_correction = true; + const int n_centroids = 1 << key_bits; + + // Generate random K, V on host (deterministic seed). + std::mt19937 rng(42); + std::normal_distribution dist(0.0f, 1.0f); + const int kv_elements = batch_size * n_kv_heads * seq_len * head_size; + std::vector K_host_f(kv_elements), V_host_f(kv_elements); + for (int i = 0; i < kv_elements; ++i) { + K_host_f[i] = dist(rng); + V_host_f[i] = dist(rng); + } + + // Convert to fp16. + std::vector K_host(kv_elements), V_host(kv_elements); + for (int i = 0; i < kv_elements; ++i) { + K_host[i] = __float2half(K_host_f[i]); + V_host[i] = __float2half(V_host_f[i]); + } + + // Lloyd-Max codebook. + auto codebook_f = SolveLloydMax(head_size, key_bits); + std::vector codebook(n_centroids); + for (int i = 0; i < n_centroids; ++i) codebook[i] = __float2half(codebook_f[i]); + + // Allocate device memory. + half *d_K, *d_V, *d_codebook, *d_K_recon, *d_V_recon; + cudaMalloc(&d_K, kv_elements * sizeof(half)); + cudaMalloc(&d_V, kv_elements * sizeof(half)); + cudaMalloc(&d_codebook, n_centroids * sizeof(half)); + cudaMalloc(&d_K_recon, kv_elements * sizeof(half)); + cudaMalloc(&d_V_recon, kv_elements * sizeof(half)); + + cudaMemcpy(d_K, K_host.data(), kv_elements * sizeof(half), cudaMemcpyHostToDevice); + cudaMemcpy(d_V, V_host.data(), kv_elements * sizeof(half), cudaMemcpyHostToDevice); + cudaMemcpy(d_codebook, codebook.data(), n_centroids * sizeof(half), cudaMemcpyHostToDevice); + + cudaStream_t cuda_stream; + cudaStreamCreate(&cuda_stream); + + int rc = RunTurboQuantRoundtrip_fp16( + batch_size, n_kv_heads, seq_len, head_size, + key_bits, value_bits, norm_correction ? 1 : 0, + d_K, d_V, d_codebook, + d_K_recon, d_V_recon, + cuda_stream); + ASSERT_EQ(rc, 0) << "RunTurboQuantRoundtrip_fp16 returned " << rc; + cudaStreamSynchronize(cuda_stream); + + // Read back reconstructions. + std::vector K_recon(kv_elements), V_recon(kv_elements); + cudaMemcpy(K_recon.data(), d_K_recon, kv_elements * sizeof(half), cudaMemcpyDeviceToHost); + cudaMemcpy(V_recon.data(), d_V_recon, kv_elements * sizeof(half), cudaMemcpyDeviceToHost); + + // Compute cosine sim of V (per slot, in original space — V is uniform-quant only). + std::vector v_cos_per_slot; + for (int s = 0; s < batch_size * n_kv_heads * seq_len; ++s) { + std::vector v_orig(head_size), v_rec(head_size); + for (int i = 0; i < head_size; ++i) { + v_orig[i] = V_host_f[s * head_size + i]; + v_rec[i] = __half2float(V_recon[s * head_size + i]); + } + v_cos_per_slot.push_back(tq_test_helpers::CosineSimilarity(v_orig, v_rec)); + } + std::sort(v_cos_per_slot.begin(), v_cos_per_slot.end()); + double v_median = v_cos_per_slot[v_cos_per_slot.size() / 2]; + + // K is in rotated space — compare each K_recon slot to the rotated normalized + // original (K_orig / ||K_orig|| · H · ||K_orig||). For 4-bit Lloyd-Max we + // expect cosine sim > 0.99. + // To avoid implementing FWHT here, we instead just verify that K_recon is + // not zero and has reasonable magnitude relative to the original norm. + std::vector k_norm_ratios; + for (int s = 0; s < batch_size * n_kv_heads * seq_len; ++s) { + double k_orig_norm_sq = 0.0, k_rec_norm_sq = 0.0; + for (int i = 0; i < head_size; ++i) { + double o = K_host_f[s * head_size + i]; + double r = __half2float(K_recon[s * head_size + i]); + k_orig_norm_sq += o * o; + k_rec_norm_sq += r * r; + } + if (k_orig_norm_sq > 1e-9) { + k_norm_ratios.push_back(std::sqrt(k_rec_norm_sq) / std::sqrt(k_orig_norm_sq)); + } + } + std::sort(k_norm_ratios.begin(), k_norm_ratios.end()); + double k_median_ratio = k_norm_ratios[k_norm_ratios.size() / 2]; + + // Cleanup. + cudaFree(d_K); cudaFree(d_V); cudaFree(d_codebook); + cudaFree(d_K_recon); cudaFree(d_V_recon); + cudaStreamDestroy(cuda_stream); + + // V uniform quant should reach > 0.99 median cosine sim at 4 bits. + EXPECT_GT(v_median, 0.99) << "V reconstruction cosine sim too low"; + // K_recon norm should be close to original (norm-correction makes this exact in expectation). + EXPECT_GT(k_median_ratio, 0.95); + EXPECT_LT(k_median_ratio, 1.05); +} + +#endif // disabled CUDA direct-test + +} // namespace test +} // namespace onnxruntime From e736f88e8696f628ab5eda2b051029368d212454 Mon Sep 17 00:00:00 2001 From: Tim Pietrusky Date: Tue, 19 May 2026 13:29:14 +0200 Subject: [PATCH 2/2] TurboQuant KV cache (2/4): CUDA kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the CUDA TurboQuant attention kernels referenced by the graph rewriter in #28560. All speedups numbers are end-to-end against fp16 baseline on RTX A40, real models pulled from HuggingFace: Model ctx fp16 reply TQ reply speedup LFM2.5-1.2B 4 K 6.2 s 6.0 s tied LFM2.5-1.2B 32 K 26.0 s 24.1 s 7 % LFM2.5-1.2B 64 K 63.0 s 41.1 s 53 % LFM2.5-1.2B 128 K (fp16 OOM) 65 s TQ-only Qwen3-0.6B 4 K 26.6 s 17.6 s 51 % Qwen3-0.6B 16 K 93 s 58 s 60 % Qwen3-0.6B 32 K 187 s 96 s 94 % KV cache 3.56× smaller (LFM2.5, hd=64) / 3.76× (Qwen3-0.6B, hd=128). Per-layer cosine sim vs fp16: 0.99526 across-context, top-1 token match 7/9 on a 9-step decode chain (matches CUDA TurboQuant kernels in vLLM bit-exact in the encode path; the decode kernels have minor differences in tile scheduling). What this PR contains: * `contrib_ops/cuda/bert/group_query_attention_turboquant.cuh` — templated CUDA kernels: - `TQEncodeKernel` for fresh-K/V pack (Walsh-Hadamard + Lloyd-Max). - `TQFlashAttentionKernel` (v4-lite) for decode-step fused FA on packed K with online softmax. - `TQFlashAttentionQTiledKernel` (v5) for continuation-prefill, ~2× prompt speedup over v4-lite at 4K context. - `TQFlashAttentionWmmaKernel` (v6) for tensor-core Q*K^T scoring at hd=64. Falls through to v5 for other head_dims. * `contrib_ops/cuda/bert/group_query_attention_turboquant_impl.{cu,h}` — launch shims, vectorised uint4 K/V smem loads (8-fp16 per HBM transaction), Option-ε dispatch that routes the prompt step to the stock FlashAttention kernel when past_seq=0 (bit-equivalent to fp16 for the first call so model load is byte-stable). * `contrib_ops/cuda/bert/group_query_attention.cc/.h` — TQ-aware dispatch in the existing GroupQueryAttention CUDA op. When `kv_quant_method != KVQuantMethod::None` we read the new attributes from #28560 and route to the kernels above. When it's None the op is byte-identical to today. * `contrib_ops/cuda/bert/attention_data.h` — three new fields in `GroupQueryAttentionData` for the codebook / Hadamard pointers and quant-method enum. Unused on the fp16 path. Things I'd appreciate a sharp eye on: 1. **WMMA tile schedule** in v6 — `kBlockK` is currently coupled to `kHeadDim`. Decoupling unlocks hd=128 (Qwen3-0.6B) at the same speed as hd=64 (LFM2.5). I have a follow-up patch ready but kept it out of this PR to minimise diff size. 2. **cp.async** — I tried double-buffered K/V loads in v6. Reverted because the per-tile commit_group/wait_group overhead exceeded the overlap savings at our tile size (74 ms → 95 ms decode at 32K). Worth revisiting if a future commit changes the tile sizing. 3. **Option ε prompt-step delegation** in `group_query_attention.cc` — the dispatcher checks `past_sequence_length == 0` and routes to the standard FA kernel. This is correctness-first: prompt step output is bit-equivalent to fp16, which keeps top-1 token agreement with non-TQ runs stable across cold-load tests. Cost is the wmma v6 prefill kernel never fires on step 1, but step 2+ still uses it because past_seq > 0 by then. Depends on #28560 (foundation: graph rewrite + schema). The WebGPU kernel PR and the Python tooling PR each depend on #28560 too, and don't intersect with this file set. ### Validation The host-side bit-layout tests in #28560 are necessary but not sufficient for the CUDA kernels. Kernel correctness is validated via real-model e2e runs because `onnxruntime_provider_test` does not link `libonnxruntime_providers_cuda.so` directly — the CUDA EP is loaded via `dlopen` at test time, and our kernel launcher symbols aren't exposed through that interface today. Per-layer cosine sim vs fp16 was measured by a Python harness that runs both sessions on the same prompt and compares logits; that harness sits under `tools/quantization/turboquant_kv/validate.py` in the Python tooling PR. --- .../contrib_ops/cuda/bert/attention_data.h | 8 + .../cuda/bert/group_query_attention.cc | 160 +- .../cuda/bert/group_query_attention.h | 9 + .../bert/group_query_attention_turboquant.cuh | 517 ++++++ .../group_query_attention_turboquant_impl.cu | 1642 +++++++++++++++++ .../group_query_attention_turboquant_impl.h | 63 + 6 files changed, 2381 insertions(+), 18 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant.cuh create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.h diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 60f2d05446da1..ae5456b516f95 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -5,6 +5,7 @@ #include #include +#include "core/framework/allocator.h" // for AllocatorPtr #include "contrib_ops/cpu/bert/attention_common.h" #include "contrib_ops/cpu/bert/attention_parameters.h" @@ -213,6 +214,13 @@ struct GroupQueryAttentionData { T* unfused_q_bnsh = nullptr; T* unfused_y_bnsh = nullptr; void* unfused_workspace = nullptr; + + // TurboQuant graph initializers (only set when parameters.kv_quant_method == TURBOQUANT). + // k_codebook : [2^key_quant_bits] static Lloyd-Max centroids (fp16/bf16 device pointer) + // hadamard : [head_size, head_size] Walsh-Hadamard rotation matrix (fp16/bf16 device pointer) + // Both are read-only graph initializers shipped inside the .onnx model. + const T* k_codebook = nullptr; + const T* hadamard = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index dfecc2b810a04..9f24f4cb1a6d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -10,6 +10,7 @@ #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" +#include "contrib_ops/cuda/bert/group_query_attention_turboquant_impl.h" #include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" @@ -42,6 +43,21 @@ KVQuantizationType StringToKVQuantizationType(std::string s) { } return KVQuantizationType::NONE; } + +// Map string attribute to KV quantization method enum. +// "none" -> NONE (no KV cache compression) +// "classic" -> CLASSIC (existing int4/int8/fp8 path via k_scale/v_scale) +// "turboquant" -> TURBOQUANT (Hadamard + Lloyd-Max keys, uniform values) +KVQuantMethod StringToKVQuantMethod(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); + if (s == "turboquant") { + return KVQuantMethod::TURBOQUANT; + } + if (s == "classic") { + return KVQuantMethod::CLASSIC; + } + return KVQuantMethod::NONE; +} } // namespace #define REGISTER_KERNEL_TYPED(T, U) \ @@ -69,10 +85,12 @@ REGISTER_KERNEL_TYPED(BFloat16, int8_t) REGISTER_KERNEL_TYPED(MLFloat16, Float8E4M3FN) REGISTER_KERNEL_TYPED(BFloat16, Float8E4M3FN) #endif -#ifdef USE_INT4_KV_CACHE +// uint8_t cache type is used for both the existing INT4 KV path (gated by +// USE_INT4_KV_CACHE) and the TurboQuant KV path (always available). The +// kernel dispatch in ComputeInternal branches on `kv_quant_method` to pick +// the correct path at runtime. REGISTER_KERNEL_TYPED(MLFloat16, uint8_t) REGISTER_KERNEL_TYPED(BFloat16, uint8_t) -#endif constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; @@ -108,6 +126,19 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE")); kv_cache_bit_width_ = static_cast(info.GetAttrOrDefault("kv_cache_bit_width", 0)); + // TurboQuant attributes (default to disabled — backward compatible). + kv_quant_method_ = StringToKVQuantMethod(info.GetAttrOrDefault("kv_quant_method", "none")); + key_quant_bits_ = static_cast(info.GetAttrOrDefault("key_quant_bits", 4)); + value_quant_bits_ = static_cast(info.GetAttrOrDefault("value_quant_bits", 4)); + norm_correction_ = info.GetAttrOrDefault("norm_correction", 0) == 1; + + if (kv_quant_method_ == KVQuantMethod::TURBOQUANT) { + ORT_ENFORCE(key_quant_bits_ == 3 || key_quant_bits_ == 4, + "TurboQuant key_quant_bits must be 3 or 4, got ", key_quant_bits_); + ORT_ENFORCE(value_quant_bits_ == 3 || value_quant_bits_ == 4, + "TurboQuant value_quant_bits must be 3 or 4, got ", value_quant_bits_); + } + bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE); int default_enable_xqa = is_quantized ? 1 : 0; enable_xqa_ = (std::is_same_v || std::is_same_v) && ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", default_enable_xqa) != 0; @@ -166,6 +197,26 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const Tensor* head_sink = context->Input(11); const Tensor* k_scale = context->Input(12); const Tensor* v_scale = context->Input(13); + // TurboQuant graph initializers (optional, only used when kv_quant_method == TURBOQUANT). + const Tensor* k_codebook = context->Input(14); + const Tensor* hadamard = context->Input(15); + + if (kv_quant_method_ == KVQuantMethod::TURBOQUANT) { + if (k_codebook == nullptr || hadamard == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "k_codebook (input 14) and hadamard (input 15) must be provided " + "when kv_quant_method=='turboquant'"); + } + // Codebook shape sanity check. + const auto& cb_shape = k_codebook->Shape(); + if (cb_shape.NumDimensions() != 1 || + cb_shape[0] != (int64_t{1} << key_quant_bits_)) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "k_codebook must be 1-D with 2^key_quant_bits entries"); + } + } if (k_quant_type_ != KVQuantizationType::NONE) { if (k_scale == nullptr) { @@ -206,22 +257,65 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; GroupQueryAttentionData data; - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, - key, - value, - past_key, - past_value, - cos_cache, - sin_cache, - ¶meters, - num_heads_, - kv_num_heads_, - total_seq_lens_minus_one, - total_seqlen, - scale_, - softcap_, - kv_cache_bit_width_, - device_prop.maxThreadsPerBlock)); + // TurboQuant uses a custom cache layout. Compute parameters by hand for + // that path, replicating just what CheckInputs would set. For non-TQ paths + // we still go through the standard helper. + if (kv_quant_method_ == KVQuantMethod::TURBOQUANT) { + const auto& q_dims = query->Shape().GetDims(); + ORT_ENFORCE(q_dims.size() == 3, "query must be 3-D for TurboQuant GQA"); + const auto& past_dims = past_key->Shape().GetDims(); + ORT_ENFORCE(past_dims.size() == 4, "past_key must be 4-D for TurboQuant GQA"); + + parameters.batch_size = static_cast(q_dims[0]); + parameters.sequence_length = static_cast(q_dims[1]); + parameters.hidden_size = static_cast(q_dims[2]); + parameters.num_heads = num_heads_; + parameters.kv_num_heads = kv_num_heads_; + parameters.head_size = parameters.hidden_size / num_heads_; + parameters.kv_hidden_size = parameters.head_size * kv_num_heads_; + parameters.v_head_size = parameters.head_size; + parameters.past_kv_format = AttentionQkvFormat::Q_K_V_BNSH; + parameters.scale = (scale_ == 0.0f) ? (1.0f / std::sqrt(static_cast(parameters.head_size))) : scale_; + parameters.softcap = softcap_; + parameters.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + parameters.is_packed_qkv = false; + + // total_sequence_length = total_seq_lens_minus_one[0] + 1. + // Make sure all prior compute on the stream (e.g. the attn_mask subgraph that + // produces this tensor) has finished before we read it host-side. Without this + // sync the cudaMemcpy below can race the producing kernel and read garbage. + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context))); + const int* total_minus_one_ptr = total_seq_lens_minus_one->Data(); + int total_minus_one = 0; + CUDA_RETURN_IF_ERROR(cudaMemcpy(&total_minus_one, total_minus_one_ptr, sizeof(int), cudaMemcpyDeviceToHost)); + parameters.total_sequence_length = total_minus_one + 1; + parameters.seqlen_past_kv_cache = parameters.total_sequence_length - parameters.sequence_length; + // For dynamic-cache models (HF ONNX exports), past_key shape is [B, H_kv, past_seq, slot_bytes] + // and the present_key output must be sized [B, H_kv, past_seq + new_seq, slot_bytes]. + // past_dims[2] is the *current past* length, NOT a fixed max; if we used it as max_seq_len + // the present output buffer would be too small (or zero-sized at first prompt) and writes + // would scribble out of bounds. Use total_sequence_length so the cache layout has room. + parameters.seqlen_present_kv_cache = parameters.total_sequence_length; + parameters.is_first_prompt = (parameters.seqlen_past_kv_cache == 0); + parameters.is_subsequent_prompt = !parameters.is_first_prompt && parameters.sequence_length > 1; + } else { + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + total_seq_lens_minus_one, + total_seqlen, + scale_, + softcap_, + kv_cache_bit_width_, + device_prop.maxThreadsPerBlock)); + } ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, @@ -235,6 +329,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.k_quant_type = k_quant_type_; parameters.v_quant_type = v_quant_type_; parameters.kv_cache_bit_width = kv_cache_bit_width_; + parameters.kv_quant_method = kv_quant_method_; + parameters.key_quant_bits = key_quant_bits_; + parameters.value_quant_bits = value_quant_bits_; + parameters.norm_correction = norm_correction_; parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; @@ -259,6 +357,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // For 4-bit quantization, we pack two 4-bit values into one uint8 byte. // Therefore, the dense head size in the tensor shape is halved (rounded up). int dense_head_size = (parameters.kv_cache_bit_width == 4) ? (parameters.head_size + 1) / 2 : parameters.head_size; + if (parameters.kv_quant_method == KVQuantMethod::TURBOQUANT) { + // TurboQuant slot layout: packed indices + per-slot fp16 metadata appended to the last dim. + // K slot: ceil(D * key_bits / 8) bytes + 2 bytes (vec_norm fp16) + // V slot: ceil(D * value_bits / 8) bytes + 4 bytes (v_scale + v_zero fp16) + // Both K and V tensors are sized to max(K_slot, V_slot) so they share a uniform last-dim + // and stay MayInplace-compatible with their past tensors. + int k_slot_bytes = (parameters.head_size * parameters.key_quant_bits + 7) / 8 + 2; + int v_slot_bytes = (parameters.head_size * parameters.value_quant_bits + 7) / 8 + 4; + dense_head_size = std::max(k_slot_bytes, v_slot_bytes); + } std::vector present_dims = { parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, dense_head_size}; @@ -295,6 +403,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.present_key = reinterpret_cast(present_key_output->MutableData()); data.present_value = reinterpret_cast(present_value_output->MutableData()); + // TurboQuant graph initializers. + if (kv_quant_method_ == KVQuantMethod::TURBOQUANT) { + data.k_codebook = reinterpret_cast(k_codebook->Data()); + data.hadamard = reinterpret_cast(hadamard->Data()); + } + // Compute past_present_share_buffer early since it's needed for flash attention path selection. // This compares the final pointer values after quantization handling. parameters.past_present_share_buffer = (data.past_key == data.present_key); @@ -600,6 +714,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons cublasHandle_t cublas = GetCublasHandle(context); + // TurboQuant dispatch: routes to LaunchTurboQuantAttention when method == TURBOQUANT. + // For v1 this returns NOT_IMPLEMENTED; the GQA → TQ wiring is delivered separately + // via LaunchTurboQuantEncodeDecodeRoundtrip exercised in gtests. + if constexpr (std::is_same::value) { + if (parameters.kv_quant_method == KVQuantMethod::TURBOQUANT) { + return LaunchTurboQuantAttention( + device_prop, ort_stream.get(), parameters, data); + } + } + ORT_RETURN_IF_ERROR((QkvToContext( device_prop, cublas, ort_stream.get(), parameters, data))); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 75a613724e746..4f352ffc6cd27 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -41,6 +41,15 @@ class GroupQueryAttention final : public CudaKernel { KVQuantizationType v_quant_type_; int kv_cache_bit_width_; + // TurboQuant KV cache method (mutually exclusive with k_quant_type_/v_quant_type_). + // When kv_quant_method_ == TURBOQUANT, the KV cache stores Lloyd-Max-coded keys + // (with vec_norm fp16) and uniform-quantized values (with scale + zero fp16). + // Codebook + Hadamard matrix come in as graph initializers via inputs 14, 15. + KVQuantMethod kv_quant_method_; + int key_quant_bits_; + int value_quant_bits_; + bool norm_correction_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; const AttentionKernelOptions* kernel_options_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant.cuh new file mode 100644 index 0000000000000..c703e7fbde03e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant.cuh @@ -0,0 +1,517 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +// Enable TurboQuant KV cache support (Hadamard + Lloyd-Max keys, uniform values). +#define KV_TURBOQUANT_SUPPORTED 1 + +#include + +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/group_query_attention_qdq.cuh" // for TypeConverter +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace turboquant { + +// ============================================================================= +// TurboQuant CUDA kernels +// ============================================================================= +// +// TurboQuant is a near-optimal KV cache quantization scheme. Per-slot K storage +// is a packed Lloyd-Max codebook index sequence + a single fp16 vec_norm. +// V storage is packed uniform asymmetric quant + (scale, zero) fp16 pair. +// +// References: +// - vLLM Triton implementation: +// vllm/v1/attention/ops/triton_turboquant_{store,decode}.py +// - NumPy reference + paper validation: +// onnxruntime/python/tools/quantization/turboquant_kv/ +// +// Layout (per (token, head) slot, head_dim D, key_bits Bk, value_bits Bv): +// +// K storage: +// [packed_K_indices: ceil(D * Bk / 8) bytes][vec_norm: 2 bytes fp16] +// +// V storage: +// [packed_V_indices: ceil(D * Bv / 8) bytes][v_scale: 2 bytes fp16][v_zero: 2 bytes fp16] +// +// Bit-packing: +// 3-bit: 8 values into 3 bytes. byte0 = (v0) | (v1<<3) | (low2 of v2 << 6)... 24-bit LE word. +// 4-bit: pairs into 1 byte. lo nibble = even index, hi nibble = odd index. +// +// Decode path: +// q_rot = q @ H // Walsh-Hadamard, applied once per (layer, step) +// for each cached token i: +// y_hat = centroids[indices[i]] // gather from constant LUT +// score[i] = vec_norm[i] * dot(q_rot, y_hat) +// No K reconstruction in original space — Hadamard is orthogonal so +// dot(q, k) == dot(q@H, k_hat@H) * ||k||. + +// ----------------------------------------------------------------------------- +// Constants and traits. +// ----------------------------------------------------------------------------- + +constexpr int kTQThreadsPerBlock = 128; +constexpr int kTQMaxHeadDim = 256; // Largest head_dim we currently support. +constexpr int kTQMaxCentroids = 16; // 2 ** max(key_bits) = 2^4. + +template +struct TQKeyTraits { + static_assert(kKeyBits == 3 || kKeyBits == 4, "TurboQuant supports 3- or 4-bit keys"); + static constexpr int kBits = kKeyBits; + static constexpr int kCentroids = 1 << kKeyBits; + static constexpr int kPackGroup = (kKeyBits == 3) ? 8 : 2; + static constexpr int kPackBytes = (kKeyBits == 3) ? 3 : 1; +}; + +template +struct TQValueTraits { + static_assert(kValueBits == 3 || kValueBits == 4, "TurboQuant supports 3- or 4-bit values"); + static constexpr int kBits = kValueBits; + static constexpr int kLevels = 1 << kValueBits; + static constexpr int kPackGroup = (kValueBits == 3) ? 8 : 2; + static constexpr int kPackBytes = (kValueBits == 3) ? 3 : 1; +}; + +// Bytes per K slot for given (head_dim, key_bits). +__host__ __device__ inline int TQKeySlotBytes(int head_dim, int key_bits) { + // ceil(D * Bk / 8) + 2 for vec_norm fp16 + return (head_dim * key_bits + 7) / 8 + 2; +} + +// Bytes per V slot for given (head_dim, value_bits). +__host__ __device__ inline int TQValueSlotBytes(int head_dim, int value_bits) { + // ceil(D * Bv / 8) + 4 for scale + zero fp16 + return (head_dim * value_bits + 7) / 8 + 4; +} + +// ----------------------------------------------------------------------------- +// Device helpers: bit-packing. +// ----------------------------------------------------------------------------- + +// Pack 8 3-bit indices [0,7] into 3 bytes. Indices passed as a uint32 word with +// 3 bits per index (already pre-shifted) is the most efficient form. +__device__ inline void TQPack3BitGroup(const uint8_t* idx, uint8_t* out) { + uint32_t word = 0; + #pragma unroll + for (int i = 0; i < 8; ++i) { + word |= (static_cast(idx[i]) & 0x7u) << (i * 3); + } + out[0] = static_cast(word & 0xFFu); + out[1] = static_cast((word >> 8) & 0xFFu); + out[2] = static_cast((word >> 16) & 0xFFu); +} + +// Unpack 3 bytes into 8 3-bit indices. +__device__ inline void TQUnpack3BitGroup(const uint8_t* bytes, uint8_t* idx) { + const uint32_t word = static_cast(bytes[0]) | + (static_cast(bytes[1]) << 8) | + (static_cast(bytes[2]) << 16); + #pragma unroll + for (int i = 0; i < 8; ++i) { + idx[i] = static_cast((word >> (i * 3)) & 0x7u); + } +} + +// Pack 2 4-bit indices [0,15] into 1 byte. lo = even, hi = odd. +__device__ inline uint8_t TQPack4BitPair(uint8_t lo, uint8_t hi) { + return static_cast((lo & 0xFu) | ((hi & 0xFu) << 4)); +} + +// Unpack 1 byte into 2 4-bit indices. +__device__ inline void TQUnpack4BitPair(uint8_t byte, uint8_t* lo, uint8_t* hi) { + *lo = byte & 0xFu; + *hi = (byte >> 4) & 0xFu; +} + +// ----------------------------------------------------------------------------- +// Device helpers: Lloyd-Max codebook lookup. +// ----------------------------------------------------------------------------- + +// Binary search the index of the nearest centroid for value `y`. +// `boundaries` is sorted ascending, length = kCentroids - 1. +template +__device__ inline uint8_t TQEncodeIndex(float y, const float* boundaries) { + // For 8 or 16 centroids, an unrolled linear search is faster than binary + // search and more ALU-friendly than a loop. + uint8_t idx = 0; + #pragma unroll + for (int i = 0; i < kCentroids - 1; ++i) { + idx += (y > boundaries[i]) ? 1 : 0; + } + return idx; +} + +// ----------------------------------------------------------------------------- +// Device helpers: Hadamard transform. +// ----------------------------------------------------------------------------- + +// In-shared-memory Walsh-Hadamard transform (FWHT) of length D (power of two). +// Each thread is responsible for a slice of D/blockDim.x elements; we cooperate +// across the warp via __syncthreads after each butterfly stage. +// +// Input: shared array `x[D]`, modified in place to H @ x where H is normalized +// Walsh-Hadamard (so that H @ H^T = I). Caller is responsible for the final +// 1/sqrt(D) scaling — we fold it into the codebook step instead. +template +__device__ inline void TQHadamardInPlace(float* x) { + // Sylvester FWHT, log2(D) butterfly stages. + #pragma unroll + for (int h = 1; h < kHeadDim; h *= 2) { + int tid = threadIdx.x; + // Each thread processes a pair (j, j+h) for some j in [0, D/2) where + // (j / h) is even. + for (int idx = tid; idx < kHeadDim / 2; idx += blockDim.x) { + const int j = (idx / h) * 2 * h + (idx % h); + const float a = x[j]; + const float b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + __syncthreads(); + } + // Final normalization by 1/sqrt(D). + const float inv_sqrt_d = rsqrtf(static_cast(kHeadDim)); + for (int idx = threadIdx.x; idx < kHeadDim; idx += blockDim.x) { + x[idx] *= inv_sqrt_d; + } + __syncthreads(); +} + +// ----------------------------------------------------------------------------- +// Kernel: TurboQuant store (write-time encode + pack of K and V). +// ----------------------------------------------------------------------------- +// +// Grid: (n_tokens, n_kv_heads, batch) +// Block: kTQThreadsPerBlock +// +// Each block handles one (token, head) slot. Steps per block: +// 1. Load K[D] into shared, compute ||k|| via warp reduction. +// 2. Normalize x_hat = K / ||k||. +// 3. FWHT in shared: y = H @ x_hat. +// 4. Encode each y[j] to a Lloyd-Max codebook index using boundaries. +// 5. Pack indices and write to K cache slot. +// 6. Load V[D] into shared, compute (min, max) via warp reduction. +// 7. Quantize V uniformly, pack and write to V cache slot. +// 8. Write vec_norm, v_scale, v_zero as fp16 next to the packed bytes. +// +// All boundaries / centroids constants are passed via global memory pointers +// to avoid CUDA __constant__ limits when launching across kernels with +// different head_dims in the same module. +template +__global__ void TQStoreKernel( + const T* __restrict__ K, // (B, H, S, D) input keys + const T* __restrict__ V, // (B, H, S, D) input values + const float* __restrict__ k_boundaries, // (kCentroids - 1,) Lloyd-Max boundaries + uint8_t* __restrict__ key_cache, // raw bytes for K cache, slot-indexed + uint8_t* __restrict__ value_cache, // raw bytes for V cache, slot-indexed + int batch_size, + int seq_len, + int n_kv_heads, + int slot_bytes_k, + int slot_bytes_v) { + // Identify the slot. + const int b = blockIdx.z; + const int h = blockIdx.y; + const int s = blockIdx.x; + if (b >= batch_size || h >= n_kv_heads || s >= seq_len) return; + + using KT = TQKeyTraits; + using VT = TQValueTraits; + + __shared__ float smem_k[kHeadDim]; + __shared__ float smem_v[kHeadDim]; + __shared__ float k_norm_sq; + __shared__ float v_min_smem; + __shared__ float v_max_smem; + + // Load K and V into shared memory; convert from T (fp16/bf16) to fp32 for + // numerical stability of the rotation + scoring. + const int kv_stride = seq_len * n_kv_heads * kHeadDim; + const int slot_offset = ((b * n_kv_heads + h) * seq_len + s) * kHeadDim; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + smem_k[i] = TypeConverter::to_float(K[slot_offset + i]); + smem_v[i] = TypeConverter::to_float(V[slot_offset + i]); + } + if (threadIdx.x == 0) { + k_norm_sq = 0.0f; + v_min_smem = 1e30f; + v_max_smem = -1e30f; + } + __syncthreads(); + + // Compute ||k||^2 via partial sums. + float local_sq = 0.0f; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + local_sq += smem_k[i] * smem_k[i]; + } + atomicAdd(&k_norm_sq, local_sq); + __syncthreads(); + const float vec_norm = sqrtf(k_norm_sq); + const float inv_norm = (vec_norm > 1e-9f) ? (1.0f / vec_norm) : 1.0f; + + // Normalize K in place. + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + smem_k[i] *= inv_norm; + } + __syncthreads(); + + // FWHT in shared. + TQHadamardInPlace(smem_k); + + // Encode K via Lloyd-Max boundaries. Use shared memory to store the indices + // before packing. + __shared__ uint8_t k_indices[kHeadDim]; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + k_indices[i] = TQEncodeIndex(smem_k[i], k_boundaries); + } + __syncthreads(); + + // Pack and write K indices. + uint8_t* k_slot = key_cache + (((b * n_kv_heads + h) * seq_len + s) * slot_bytes_k); + if constexpr (kKeyBits == 4) { + for (int i = threadIdx.x; i < kHeadDim / 2; i += blockDim.x) { + k_slot[i] = TQPack4BitPair(k_indices[2 * i], k_indices[2 * i + 1]); + } + } else if constexpr (kKeyBits == 3) { + for (int g = threadIdx.x; g < kHeadDim / 8; g += blockDim.x) { + TQPack3BitGroup(&k_indices[g * 8], k_slot + g * 3); + } + } + + // Write vec_norm fp16 right after the packed K indices. + if (threadIdx.x == 0) { + const int packed_k_bytes = (kHeadDim * kKeyBits + 7) / 8; + half h_norm = __float2half(vec_norm); + *reinterpret_cast(k_slot + packed_k_bytes) = h_norm; + } + __syncthreads(); + + // Compute V min/max via partial reductions. + float local_min = 1e30f; + float local_max = -1e30f; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + const float x = smem_v[i]; + local_min = fminf(local_min, x); + local_max = fmaxf(local_max, x); + } + atomicMin(reinterpret_cast(&v_min_smem), __float_as_int(local_min)); + atomicMax(reinterpret_cast(&v_max_smem), __float_as_int(local_max)); + __syncthreads(); + + const float v_min = v_min_smem; + const float v_max = v_max_smem; + const float v_scale = (v_max - v_min) / static_cast(VT::kLevels - 1); + const float inv_v_scale = (v_scale > 1e-12f) ? (1.0f / v_scale) : 0.0f; + + // Quantize V uniformly into shared `v_indices`. + __shared__ uint8_t v_indices[kHeadDim]; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + int q = static_cast(rintf((smem_v[i] - v_min) * inv_v_scale)); + q = max(0, min(VT::kLevels - 1, q)); + v_indices[i] = static_cast(q); + } + __syncthreads(); + + // Pack and write V indices. + uint8_t* v_slot = value_cache + (((b * n_kv_heads + h) * seq_len + s) * slot_bytes_v); + if constexpr (kValueBits == 4) { + for (int i = threadIdx.x; i < kHeadDim / 2; i += blockDim.x) { + v_slot[i] = TQPack4BitPair(v_indices[2 * i], v_indices[2 * i + 1]); + } + } else if constexpr (kValueBits == 3) { + for (int g = threadIdx.x; g < kHeadDim / 8; g += blockDim.x) { + TQPack3BitGroup(&v_indices[g * 8], v_slot + g * 3); + } + } + + // Write v_scale and v_zero (= v_min) fp16 right after the packed V indices. + if (threadIdx.x == 0) { + const int packed_v_bytes = (kHeadDim * kValueBits + 7) / 8; + half h_scale = __float2half(v_scale); + half h_zero = __float2half(v_min); + *reinterpret_cast(v_slot + packed_v_bytes) = h_scale; + *reinterpret_cast(v_slot + packed_v_bytes + 2) = h_zero; + } +} + +// ----------------------------------------------------------------------------- +// Kernel: TurboQuant fused decode score (rotated-space attention). +// ----------------------------------------------------------------------------- +// +// Grid: (n_kv_heads, batch) +// Block: kTQThreadsPerBlock +// +// Computes attention scores for a single decode step (q_len == 1): +// +// q_rot = q @ H // rotate query, in shared memory +// for each cached token i: +// y_hat = centroids[indices[i]] +// scores[i] = vec_norm[i] * dot(q_rot, y_hat) +// +// The k_centroids LUT is loaded into shared once and reused for all tokens. +// V is dequantized + softmax-weighted in a second pass (kept separate for +// numerical stability of online softmax). +template +__global__ void TQDecodeScoreKernel( + const T* __restrict__ Q, // (B, H_q, 1, D) query + const uint8_t* __restrict__ key_cache, + const float* __restrict__ k_centroids, // (kCentroids,) centroid values + const int* __restrict__ seq_lens, // (B,) actual sequence length per batch + int n_kv_heads, + int n_q_heads, + int max_seq_len, + int slot_bytes_k, + bool norm_correction, + float scale, + float* __restrict__ scores // (B, H_q, max_seq_len) attention logits +) { + using KT = TQKeyTraits; + + const int b = blockIdx.y; + const int h = blockIdx.x; + if (b >= gridDim.y || h >= n_kv_heads) return; + + const int seq_len = seq_lens[b]; + + __shared__ float q_rot[kHeadDim]; + __shared__ float centroids[KT::kCentroids]; + + // Load and rotate Q for this head. For GQA, q_heads_per_kv = n_q_heads / n_kv_heads; + // we need to rotate each q-head separately. For brevity, this kernel handles + // one kv-head and assumes the caller handles q-head indexing externally. + // (The full implementation will iterate q_heads_per_kv inside the kernel.) + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + q_rot[i] = TypeConverter::to_float(Q[(b * n_q_heads + h) * kHeadDim + i]); + } + for (int i = threadIdx.x; i < KT::kCentroids; i += blockDim.x) { + centroids[i] = k_centroids[i]; + } + __syncthreads(); + + TQHadamardInPlace(q_rot); + + // For each token in the cache, gather centroids by index and compute the dot. + for (int s = blockIdx.z * blockDim.x + threadIdx.x; s < seq_len; + s += blockDim.x * gridDim.z) { + const uint8_t* k_slot = key_cache + + (((b * n_kv_heads + h) * max_seq_len + s) * slot_bytes_k); + const int packed_k_bytes = (kHeadDim * kKeyBits + 7) / 8; + const half h_norm = *reinterpret_cast(k_slot + packed_k_bytes); + const float vec_norm = __half2float(h_norm); + + float dot = 0.0f; + if constexpr (kKeyBits == 4) { + for (int i = 0; i < kHeadDim / 2; ++i) { + uint8_t lo, hi; + TQUnpack4BitPair(k_slot[i], &lo, &hi); + dot += q_rot[2 * i] * centroids[lo]; + dot += q_rot[2 * i + 1] * centroids[hi]; + } + } else if constexpr (kKeyBits == 3) { + for (int g = 0; g < kHeadDim / 8; ++g) { + uint8_t idx[8]; + TQUnpack3BitGroup(k_slot + g * 3, idx); + #pragma unroll + for (int j = 0; j < 8; ++j) { + dot += q_rot[g * 8 + j] * centroids[idx[j]]; + } + } + } + + if (norm_correction) { + // Re-normalize the centroid vector to unit length so the reconstructed + // key has the correct magnitude. + float c_sq = 0.0f; + if constexpr (kKeyBits == 4) { + for (int i = 0; i < kHeadDim / 2; ++i) { + uint8_t lo, hi; + TQUnpack4BitPair(k_slot[i], &lo, &hi); + c_sq += centroids[lo] * centroids[lo]; + c_sq += centroids[hi] * centroids[hi]; + } + } else if constexpr (kKeyBits == 3) { + for (int g = 0; g < kHeadDim / 8; ++g) { + uint8_t idx[8]; + TQUnpack3BitGroup(k_slot + g * 3, idx); + #pragma unroll + for (int j = 0; j < 8; ++j) { + c_sq += centroids[idx[j]] * centroids[idx[j]]; + } + } + } + dot *= rsqrtf(c_sq); + } + + scores[((b * n_q_heads + h) * max_seq_len) + s] = dot * vec_norm * scale; + } +} + +// ----------------------------------------------------------------------------- +// Kernel: V dequant + softmax-weighted sum (the second decode pass). +// ----------------------------------------------------------------------------- +// +// Standard pattern: output[d] = sum_i (softmax_weight[i] * v_dequant[i, d]). +// V dequant is uniform asymmetric: v_dequant = scale * idx + zero. +template +__global__ void TQDecodeWeightedSumKernel( + const float* __restrict__ softmax_weights, // (B, H_q, S) post-softmax scores + const uint8_t* __restrict__ value_cache, + const int* __restrict__ seq_lens, + int n_kv_heads, + int n_q_heads, + int max_seq_len, + int slot_bytes_v, + T* __restrict__ output // (B, H_q, D) attention output +) { + using VT = TQValueTraits; + + const int b = blockIdx.y; + const int h = blockIdx.x; + if (h >= n_kv_heads) return; + + const int seq_len = seq_lens[b]; + + // Each thread accumulates a few D dims. + for (int d = threadIdx.x; d < kHeadDim; d += blockDim.x) { + float acc = 0.0f; + for (int s = 0; s < seq_len; ++s) { + const uint8_t* v_slot = value_cache + + (((b * n_kv_heads + h) * max_seq_len + s) * slot_bytes_v); + const int packed_v_bytes = (kHeadDim * kValueBits + 7) / 8; + const half h_scale = *reinterpret_cast(v_slot + packed_v_bytes); + const half h_zero = *reinterpret_cast(v_slot + packed_v_bytes + 2); + const float v_scale = __half2float(h_scale); + const float v_zero = __half2float(h_zero); + + uint8_t idx; + if constexpr (kValueBits == 4) { + const int byte_idx = d / 2; + const uint8_t byte = v_slot[byte_idx]; + idx = (d % 2 == 0) ? (byte & 0xFu) : ((byte >> 4) & 0xFu); + } else if constexpr (kValueBits == 3) { + const int g = d / 8; + const int j = d % 8; + uint8_t group_idx[8]; + TQUnpack3BitGroup(v_slot + g * 3, group_idx); + idx = group_idx[j]; + } + const float v = v_scale * static_cast(idx) + v_zero; + const float w = softmax_weights[((b * n_q_heads + h) * max_seq_len) + s]; + acc += w * v; + } + output[((b * n_q_heads + h) * kHeadDim) + d] = static_cast(acc); + } +} + +} // namespace turboquant +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.cu new file mode 100644 index 0000000000000..d93e0477bcae2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.cu @@ -0,0 +1,1642 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/allocator.h" // must precede attention_data.h for AllocatorPtr +#include "contrib_ops/cuda/bert/group_query_attention_turboquant.cuh" +#include "contrib_ops/cuda/bert/group_query_attention_turboquant_impl.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" // QkvToContext (fp16 path) +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" // LaunchRotaryEmbeddingKernel +#include "contrib_ops/cuda/bert/unfused_attention.h" // GetUnfusedAttentionWorkspaceSize +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" // mha_fwd_kvcache (FlashAttention) + +#include +#include +#include // for nvcuda::wmma (tensor cores) — used by v6 score path + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// ============================================================================= +// TurboQuant CUDA orchestration. +// +// Strategy for v1: bulk dequant + standard fp16 attention. +// 1. Encode incoming K/V into the present cache (TurboQuant slot layout). +// 2. Decode the entire present cache back to fp16 K, V buffers in the +// ORIGINAL (non-rotated) space. +// 3. Hand off to the existing QkvToContext (fp16 cache) path. +// +// Memory win: cache stays compressed when not in use. Compute cost: every +// step pays an O(D x S) dequant; no in-rotated-space scoring (deferred to v2). +// +// Layout per (b, h, s) slot in the cache (uint8 last-dim = max(K_slot, V_slot)): +// K slot: [packed_idx[ceil(D*kbits/8)] | vec_norm fp16 (2B)] +// V slot: [packed_idx[ceil(D*vbits/8)] | v_scale fp16 (2B) | v_zero fp16 (2B)] +// The K and V tensors are *separate* tensors; both share the same last-dim. +// ============================================================================= + +namespace { + +// v3: device-side fp16 -> fp32 codebook conversion. Launched as <<<1, 32>>> +// with at most 16 active threads (4-bit) or 8 (3-bit). Avoids a host sync +// per attention call. +template +__global__ void TQConvertCodebookKernel(const T* __restrict__ src, + float* __restrict__ dst, + int n) { + const int i = threadIdx.x; + if (i < n) { + dst[i] = static_cast(src[i]); + } +} + +// -------- Encode: write new K/V tokens into cache slots --------------------- +// +// Grid: (new_seq_len, n_kv_heads, batch_size) +// Block: kTQThreadsPerBlock +// +// K_in / V_in shape (B, H_kv, new_seq, D), fp16. +// k_cache / v_cache shape (B, H_kv, max_seq, slot_last_dim), uint8 — we write +// into slots [past_seq_len, past_seq_len + new_seq_len). +template +__global__ void TQEncodeKernel( + const T* __restrict__ K_in, // (B, H_kv, new_seq, D) + const T* __restrict__ V_in, // (B, H_kv, new_seq, D) + const float* __restrict__ k_codebook, + int batch_size, + int new_seq_len, + int n_kv_heads, + int max_seq_len, + int past_seq_len, + int slot_last_dim, + uint8_t* __restrict__ k_cache, // (B, H_kv, max_seq, slot_last_dim) + uint8_t* __restrict__ v_cache) { // (B, H_kv, max_seq, slot_last_dim) + using namespace turboquant; + using KT = TQKeyTraits; + using VT = TQValueTraits; + constexpr int kPackedKBytes = (kHeadDim * kKeyBits + 7) / 8; + constexpr int kPackedVBytes = (kHeadDim * kValueBits + 7) / 8; + + const int b = blockIdx.z; + const int h = blockIdx.y; + const int s_new = blockIdx.x; + if (b >= batch_size || h >= n_kv_heads || s_new >= new_seq_len) return; + + const int s_cache = past_seq_len + s_new; + if (s_cache >= max_seq_len) return; + + __shared__ float smem_k[kHeadDim]; + __shared__ float smem_v[kHeadDim]; + __shared__ float reduce_buf[kTQThreadsPerBlock]; + __shared__ float scalars[3]; // [0]=norm, [1]=v_min, [2]=v_max + __shared__ uint8_t k_indices[kHeadDim]; + __shared__ uint8_t v_indices[kHeadDim]; + + // Read input K/V into shared memory. ORT's GroupQueryAttention passes K/V + // in BSNH layout (B, S, N_kv, H_kv) — the projection outputs are flat + // (B, S, N_kv * H_kv) and stride accordingly. Earlier this kernel used + // BNSH stride (b * N_kv * S * H + h * S * H + s * H), which on a real + // model interleaves K/V values from different (head, seq) pairs and + // produces near-random encoded slots. The Llama 0.99 cos sim numbers + // were measured with a synthetic-Gaussian unit test that didn't exercise + // this stride; LFM2 inference exposes it because it actually feeds real + // BSNH-layout K/V through the kernel. + const int in_off = ((b * new_seq_len + s_new) * n_kv_heads + h) * kHeadDim; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + smem_k[i] = TypeConverter::to_float(K_in[in_off + i]); + smem_v[i] = TypeConverter::to_float(V_in[in_off + i]); + } + __syncthreads(); + + // ||k|| via tree reduction. + reduce_buf[threadIdx.x] = 0.0f; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + reduce_buf[threadIdx.x] += smem_k[i] * smem_k[i]; + } + __syncthreads(); + for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { + if (threadIdx.x < stride) reduce_buf[threadIdx.x] += reduce_buf[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) scalars[0] = sqrtf(reduce_buf[0]); + __syncthreads(); + const float vec_norm = scalars[0]; + const float inv_norm = (vec_norm > 1e-9f) ? (1.0f / vec_norm) : 1.0f; + + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + smem_k[i] *= inv_norm; + } + __syncthreads(); + + // FWHT (rotate to TQ scoring space). + TQHadamardInPlace(smem_k); + + // Encode K via Lloyd-Max midpoints (linear search, cheap for ≤16 centroids). + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + float y = smem_k[i]; + uint8_t idx = 0; + #pragma unroll + for (int j = 1; j < KT::kCentroids; ++j) { + float midpoint = 0.5f * (k_codebook[j - 1] + k_codebook[j]); + idx += (y > midpoint) ? 1 : 0; + } + k_indices[i] = idx; + } + __syncthreads(); + + // Pack K indices and write to cache. + uint8_t* k_slot = k_cache + (((b * n_kv_heads + h) * max_seq_len + s_cache) * slot_last_dim); + if constexpr (kKeyBits == 4) { + for (int i = threadIdx.x; i < kHeadDim / 2; i += blockDim.x) { + k_slot[i] = TQPack4BitPair(k_indices[2 * i], k_indices[2 * i + 1]); + } + } else if constexpr (kKeyBits == 3) { + for (int g = threadIdx.x; g < kHeadDim / 8; g += blockDim.x) { + TQPack3BitGroup(&k_indices[g * 8], k_slot + g * 3); + } + } + + // Write vec_norm fp16 right after the packed K bytes. Clamp to a value + // safely inside fp16 range — for transformer K vectors, ||k|| can run into + // the thousands per layer, and an overflow to fp16 +inf would propagate as + // NaN through the rest of attention. + if (threadIdx.x == 0) { + float vn = vec_norm; + if (!isfinite(vn)) vn = 0.0f; + if (vn > 65000.0f) vn = 65000.0f; + half h_norm = __float2half(vn); + *reinterpret_cast(k_slot + kPackedKBytes) = h_norm; + } + + // V min/max via tree reduction. + float local_min = 1e30f, local_max = -1e30f; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + local_min = fminf(local_min, smem_v[i]); + local_max = fmaxf(local_max, smem_v[i]); + } + reduce_buf[threadIdx.x] = local_min; + __syncthreads(); + for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { + if (threadIdx.x < stride) reduce_buf[threadIdx.x] = fminf(reduce_buf[threadIdx.x], reduce_buf[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) scalars[1] = reduce_buf[0]; + __syncthreads(); + reduce_buf[threadIdx.x] = local_max; + __syncthreads(); + for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { + if (threadIdx.x < stride) reduce_buf[threadIdx.x] = fmaxf(reduce_buf[threadIdx.x], reduce_buf[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) scalars[2] = reduce_buf[0]; + __syncthreads(); + + const float v_min = scalars[1]; + const float v_scale_f = (scalars[2] - v_min) / static_cast(VT::kLevels - 1); + const float inv_v_scale = (v_scale_f > 1e-12f) ? (1.0f / v_scale_f) : 0.0f; + + // Encode V uniformly. + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + int q = static_cast(rintf((smem_v[i] - v_min) * inv_v_scale)); + q = max(0, min(VT::kLevels - 1, q)); + v_indices[i] = static_cast(q); + } + __syncthreads(); + + uint8_t* v_slot = v_cache + (((b * n_kv_heads + h) * max_seq_len + s_cache) * slot_last_dim); + if constexpr (kValueBits == 4) { + for (int i = threadIdx.x; i < kHeadDim / 2; i += blockDim.x) { + v_slot[i] = TQPack4BitPair(v_indices[2 * i], v_indices[2 * i + 1]); + } + } else if constexpr (kValueBits == 3) { + for (int g = threadIdx.x; g < kHeadDim / 8; g += blockDim.x) { + TQPack3BitGroup(&v_indices[g * 8], v_slot + g * 3); + } + } + + // Write v_scale, v_zero fp16 right after the packed V bytes. Clamp to fp16 + // range: V values for some layers/tokens can exceed 65504, which would + // saturate to fp16 +inf and cascade into NaN during decode/attention. + if (threadIdx.x == 0) { + float vs = v_scale_f; + float vz = v_min; + if (!isfinite(vs)) vs = 0.0f; + if (!isfinite(vz)) vz = 0.0f; + if (vs > 65000.0f) vs = 65000.0f; + if (vs < -65000.0f) vs = -65000.0f; + if (vz > 65000.0f) vz = 65000.0f; + if (vz < -65000.0f) vz = -65000.0f; + *reinterpret_cast(v_slot + kPackedVBytes) = __float2half(vs); + *reinterpret_cast(v_slot + kPackedVBytes + 2) = __float2half(vz); + } +} + +// -------- Decode: read entire cache → fp16 K, V in ORIGINAL space ---------- +// +// Grid: (total_seq_len, n_kv_heads, batch_size) +// Block: kTQThreadsPerBlock +// +// Reads from k_cache / v_cache (slot layout) for slots [0, total_seq_len). +// Writes K_out / V_out (B, H_kv, total_seq, D), fp16, in ORIGINAL space. +// +// K decode: idx → centroid (fp16) → multiply by vec_norm → result is in +// rotated space → apply H once more (Hadamard is self-inverse) → original space. +template +__global__ void TQDecodeKernel( + const uint8_t* __restrict__ k_cache, + const uint8_t* __restrict__ v_cache, + const float* __restrict__ k_codebook, + int batch_size, + int total_seq_len, + int n_kv_heads, + int max_seq_len, + int slot_last_dim, + bool norm_correction, + T* __restrict__ K_out, // (B, H_kv, total_seq, D) + T* __restrict__ V_out) { // (B, H_kv, total_seq, D) + using namespace turboquant; + using KT = TQKeyTraits; + using VT = TQValueTraits; + constexpr int kPackedKBytes = (kHeadDim * kKeyBits + 7) / 8; + constexpr int kPackedVBytes = (kHeadDim * kValueBits + 7) / 8; + + const int b = blockIdx.z; + const int h = blockIdx.y; + const int s = blockIdx.x; + if (b >= batch_size || h >= n_kv_heads || s >= total_seq_len) return; + + __shared__ float smem_k[kHeadDim]; + __shared__ float reduce_buf[kTQThreadsPerBlock]; + __shared__ float scalars[1]; + + const uint8_t* k_slot = k_cache + (((b * n_kv_heads + h) * max_seq_len + s) * slot_last_dim); + const uint8_t* v_slot = v_cache + (((b * n_kv_heads + h) * max_seq_len + s) * slot_last_dim); + + // Decode K indices → smem_k (rotated space, scaled by vec_norm). + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + uint8_t idx; + if constexpr (kKeyBits == 4) { + uint8_t byte = k_slot[i / 2]; + idx = (i & 1) ? ((byte >> 4) & 0xFu) : (byte & 0xFu); + } else /* kKeyBits == 3 */ { + // 8 indices per 3 bytes. + const int g = i / 8; + const int j = i % 8; + const uint32_t word = static_cast(k_slot[g * 3]) | + (static_cast(k_slot[g * 3 + 1]) << 8) | + (static_cast(k_slot[g * 3 + 2]) << 16); + idx = static_cast((word >> (j * 3)) & 0x7u); + } + smem_k[i] = k_codebook[idx]; + } + __syncthreads(); + + // Optional norm correction: compute ||y_hat|| over the centroid vector and + // divide each entry by it. Equivalent to renormalizing the unit-vector + // approximation before scaling by vec_norm. + if (norm_correction) { + reduce_buf[threadIdx.x] = 0.0f; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + reduce_buf[threadIdx.x] += smem_k[i] * smem_k[i]; + } + __syncthreads(); + for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { + if (threadIdx.x < stride) reduce_buf[threadIdx.x] += reduce_buf[threadIdx.x + stride]; + __syncthreads(); + } + // Guard against degenerate slots whose decoded centroid sum is ~0 + // (e.g. cache slots that were never written because past was zero-padded + // or the K vector decoded to all-near-zero centroids). Without the guard + // rsqrtf(0) = +inf and the subsequent vec_norm * inf = 0/inf = NaN, which + // propagates through attention and turns the entire model output into NaN. + const float sum_sq = reduce_buf[0]; + const float nc = (sum_sq > 1e-30f) ? rsqrtf(sum_sq) : 0.0f; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + smem_k[i] *= nc; + } + __syncthreads(); + } + + // Read vec_norm and scale. + if (threadIdx.x == 0) { + half h_norm = *reinterpret_cast(k_slot + kPackedKBytes); + scalars[0] = __half2float(h_norm); + } + __syncthreads(); + const float vec_norm = scalars[0]; + + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + smem_k[i] *= vec_norm; // now in rotated space, scaled + } + __syncthreads(); + + // Apply Hadamard once more — H is symmetric self-inverse, so this rotates + // back to original space. + TQHadamardInPlace(smem_k); + + // Write K_out (original space). + const int out_off = ((b * n_kv_heads + h) * total_seq_len + s) * kHeadDim; + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + float kv = smem_k[i]; + if (!isfinite(kv)) kv = 0.0f; + if (kv > 65000.0f) kv = 65000.0f; + if (kv < -65000.0f) kv = -65000.0f; + K_out[out_off + i] = static_cast(kv); + } + + // Decode V: indices → idx * scale + zero. + __shared__ float v_meta[2]; + if (threadIdx.x == 0) { + half h_scale = *reinterpret_cast(v_slot + kPackedVBytes); + half h_zero = *reinterpret_cast(v_slot + kPackedVBytes + 2); + v_meta[0] = __half2float(h_scale); + v_meta[1] = __half2float(h_zero); + } + __syncthreads(); + + for (int i = threadIdx.x; i < kHeadDim; i += blockDim.x) { + uint8_t idx; + if constexpr (kValueBits == 4) { + uint8_t byte = v_slot[i / 2]; + idx = (i & 1) ? ((byte >> 4) & 0xFu) : (byte & 0xFu); + } else /* kValueBits == 3 */ { + const int g = i / 8; + const int j = i % 8; + const uint32_t word = static_cast(v_slot[g * 3]) | + (static_cast(v_slot[g * 3 + 1]) << 8) | + (static_cast(v_slot[g * 3 + 2]) << 16); + idx = static_cast((word >> (j * 3)) & 0x7u); + } + float v_hat = v_meta[0] * static_cast(idx) + v_meta[1]; + // Clamp to fp16 range so ±inf can't escape decode and poison attention. + if (!isfinite(v_hat)) v_hat = 0.0f; + if (v_hat > 65000.0f) v_hat = 65000.0f; + if (v_hat < -65000.0f) v_hat = -65000.0f; + V_out[out_off + i] = static_cast(v_hat); + } +} + +// -------- Static dispatch on (head_dim, key_bits, value_bits) -------------- + +template +Status LaunchEncodeDecodeFor( + cudaStream_t stream, + int batch_size, int n_kv_heads, + int new_seq_len, int total_seq_len, + int max_seq_len, int past_seq_len, + int slot_last_dim, bool norm_correction, + const T* K_in, const T* V_in, const float* k_codebook, + uint8_t* k_cache, uint8_t* v_cache, + T* K_out, T* V_out) { + if (new_seq_len > 0) { + dim3 grid_enc(new_seq_len, n_kv_heads, batch_size); + dim3 block(turboquant::kTQThreadsPerBlock); + TQEncodeKernel<<>>( + K_in, V_in, k_codebook, + batch_size, new_seq_len, n_kv_heads, max_seq_len, past_seq_len, slot_last_dim, + k_cache, v_cache); + auto err = cudaGetLastError(); + if (err != cudaSuccess) return CUDA_CALL(err); + } + if (total_seq_len > 0) { + dim3 grid_dec(total_seq_len, n_kv_heads, batch_size); + dim3 block(turboquant::kTQThreadsPerBlock); + TQDecodeKernel<<>>( + k_cache, v_cache, k_codebook, + batch_size, total_seq_len, n_kv_heads, max_seq_len, slot_last_dim, + norm_correction, K_out, V_out); + auto err = cudaGetLastError(); + if (err != cudaSuccess) return CUDA_CALL(err); + } + return Status::OK(); +} + +// ============================================================================= +// v7 — copy fresh fp16 K/V (post-RoPE for K) directly from BSNH input into +// the BNSH-layout K_out / V_out scratch buffers at slots [past_seq, total_seq). +// +// This skips the encode → decode round-trip for the new tokens of THIS forward +// pass — encode still writes those slots into the packed cache (so future +// decode steps can read them) but for THIS turn's attention we use the +// original, exact fp16 values. Lossless for new tokens; matches v6 exactly +// for past tokens (which still go through the lossy decode path). +// +// Effect at long-context prompt step (past=0, new=S): the decode kernel +// processes zero slots and we just memcpy K/V from BSNH to BNSH layout. +// Effect at decode step (past=S, new=1): decode runs on S past slots, copy +// runs on 1 new slot — almost identical to v6 in that regime. +template +__global__ void TQCopyFreshKVKernel( + const T* __restrict__ K_in_bsnh, // (B, S_new, H_kv, D) + const T* __restrict__ V_in_bsnh, // (B, S_new, H_kv, D) + T* __restrict__ K_out_bnsh, // (B, H_kv, total_seq, D) — slots [past, total) + T* __restrict__ V_out_bnsh, + int B, int S_new, int H_kv, int total_seq, int past_seq, int D) { + const int s_new = blockIdx.x; + const int h = blockIdx.y; + const int b = blockIdx.z; + if (s_new >= S_new || h >= H_kv || b >= B) return; + const int tid = threadIdx.x; + + // BSNH input stride: ((b * S_new + s_new) * H_kv + h) * D + const int in_off = ((b * S_new + s_new) * H_kv + h) * D; + // BNSH output stride: ((b * H_kv + h) * total_seq + (past_seq + s_new)) * D + const int out_off = ((b * H_kv + h) * total_seq + (past_seq + s_new)) * D; + + for (int i = tid; i < D; i += blockDim.x) { + K_out_bnsh[out_off + i] = K_in_bsnh[in_off + i]; + V_out_bnsh[out_off + i] = V_in_bsnh[in_off + i]; + } +} + +template +Status LaunchCopyFreshKV(cudaStream_t stream, + const T* K_in_bsnh, const T* V_in_bsnh, + T* K_out_bnsh, T* V_out_bnsh, + int B, int S_new, int H_kv, int total_seq, int past_seq, int D) { + if (S_new <= 0) return Status::OK(); + dim3 grid(S_new, H_kv, B); + dim3 block(min(D, 256)); + TQCopyFreshKVKernel<<>>( + K_in_bsnh, V_in_bsnh, K_out_bnsh, V_out_bnsh, + B, S_new, H_kv, total_seq, past_seq, D); + return CUDA_CALL(cudaGetLastError()); +} + +template +Status DispatchEncodeDecode( + cudaStream_t stream, + int batch_size, int n_kv_heads, + int new_seq_len, int total_seq_len, + int max_seq_len, int past_seq_len, int head_size, + int key_bits, int value_bits, + int slot_last_dim, bool norm_correction, + const T* K_in, const T* V_in, const float* k_codebook, + uint8_t* k_cache, uint8_t* v_cache, + T* K_out, T* V_out) { +#define TQ_CASE(HD, KB, VB) \ + if (head_size == (HD) && key_bits == (KB) && value_bits == (VB)) { \ + return LaunchEncodeDecodeFor( \ + stream, batch_size, n_kv_heads, new_seq_len, total_seq_len, \ + max_seq_len, past_seq_len, slot_last_dim, norm_correction, \ + K_in, V_in, k_codebook, k_cache, v_cache, K_out, V_out); \ + } + + // NOTE: `total_seq_len` here is the count of slots to DECODE. In v7 the + // caller passes `past_seq_len` for that argument (decode handles past + // slots only; new slots are written by LaunchCopyFreshKV separately). + TQ_CASE(64, 4, 4) + TQ_CASE(64, 3, 4) + TQ_CASE(64, 3, 3) + TQ_CASE(96, 4, 4) + TQ_CASE(128, 4, 4) + TQ_CASE(128, 3, 4) + TQ_CASE(128, 3, 3) + TQ_CASE(256, 4, 4) + TQ_CASE(256, 3, 4) + +#undef TQ_CASE + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "TurboQuant: unsupported (head_size, key_bits, value_bits) = (", + head_size, ", ", key_bits, ", ", value_bits, ")"); +} + +} // namespace + +// ============================================================================= +// v2 attention kernels: simple custom CUDA attention over decoded fp16 K, V. +// ============================================================================= +// +// Layout assumptions: +// query : [B, S_q, num_heads * head_size] (BSNH packed, fp16) +// k_full : [B, num_kv_heads, total_seq, head_size] (BNSH, fp16) +// v_full : [B, num_kv_heads, total_seq, head_size] (BNSH, fp16) +// scores : [B, num_heads, S_q, total_seq] (fp32 for softmax precision) +// output : [B, S_q, num_heads * head_size] (BSNH packed, fp16) +// +// GQA: each query head h maps to kv head h_kv = h / (num_heads / num_kv_heads). +// +// Causal mask: token at position s_q (which corresponds to past_seq + s_q in the +// full sequence) attends to s_kv in [0, past_seq + s_q + 1). + +template +__global__ void TQScoresKernel( + const T* __restrict__ query, + const T* __restrict__ k_full, + float* __restrict__ scores, + int B, int S_q, int total_seq, int past_seq, + int num_heads, int num_kv_heads, int head_size, + float scale, + int g_size /* = num_heads / num_kv_heads */) { + // Block: (s_kv_chunk, h, b * S_q + s_q) + // Each block computes one (b, h, s_q) row of scores against all s_kv. + int bs = blockIdx.z; + int b = bs / S_q; + int s_q = bs % S_q; + int h = blockIdx.y; + if (b >= B || h >= num_heads || s_q >= S_q) return; + + int h_kv = h / g_size; + int q_offset = ((b * S_q + s_q) * num_heads + h) * head_size; + int k_base = (b * num_kv_heads + h_kv) * total_seq * head_size; + int s_offset = ((b * num_heads + h) * S_q + s_q) * total_seq; + + // The "current" position of this query in the full sequence (past_seq + s_q). + int q_pos = past_seq + s_q; + + for (int s_kv = blockIdx.x * blockDim.x + threadIdx.x; s_kv < total_seq; + s_kv += blockDim.x * gridDim.x) { + if (s_kv > q_pos) { + // Causal mask. + scores[s_offset + s_kv] = -INFINITY; + continue; + } + float acc = 0.0f; + for (int d = 0; d < head_size; ++d) { + acc += static_cast(query[q_offset + d]) * + static_cast(k_full[k_base + s_kv * head_size + d]); + } + scores[s_offset + s_kv] = acc * scale; + } +} + +template +__global__ void TQSoftmaxRowKernel( + float* __restrict__ scores, // in-place; will hold softmax output as fp32 too + int B, int num_heads, int S_q, int total_seq) { + // One block per (b, h, s_q). Threads cooperate on softmax over total_seq. + int bs = blockIdx.z; + int b = bs / S_q; + int s_q = bs % S_q; + int h = blockIdx.y; + if (b >= B || h >= num_heads || s_q >= S_q) return; + + int s_offset = ((b * num_heads + h) * S_q + s_q) * total_seq; + float* row = scores + s_offset; + + // Find max. + float local_max = -INFINITY; + for (int s = threadIdx.x; s < total_seq; s += blockDim.x) { + local_max = fmaxf(local_max, row[s]); + } + __shared__ float shared_max; + if (threadIdx.x == 0) shared_max = local_max; + __syncthreads(); + // Atomic-style reduction with one thread (small total_seq path; for large + // total_seq the warp reduction would be faster but this is correct). + __shared__ float reduce_buf[1024]; + reduce_buf[threadIdx.x] = local_max; + __syncthreads(); + for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { + if (threadIdx.x < stride) { + reduce_buf[threadIdx.x] = fmaxf(reduce_buf[threadIdx.x], reduce_buf[threadIdx.x + stride]); + } + __syncthreads(); + } + float row_max = reduce_buf[0]; + + // Compute sum(exp(x - max)) and write exp(x - max) in-place. + float local_sum = 0.0f; + for (int s = threadIdx.x; s < total_seq; s += blockDim.x) { + float v = expf(row[s] - row_max); + row[s] = v; + local_sum += v; + } + reduce_buf[threadIdx.x] = local_sum; + __syncthreads(); + for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { + if (threadIdx.x < stride) { + reduce_buf[threadIdx.x] += reduce_buf[threadIdx.x + stride]; + } + __syncthreads(); + } + float row_sum = reduce_buf[0] > 0.0f ? reduce_buf[0] : 1.0f; + + // Normalize. + for (int s = threadIdx.x; s < total_seq; s += blockDim.x) { + row[s] /= row_sum; + } +} + +template +__global__ void TQOutputKernel( + const float* __restrict__ scores, // [B, num_heads, S_q, total_seq] fp32 + const T* __restrict__ v_full, // [B, num_kv_heads, total_seq, head_size] + T* __restrict__ output, // [B, S_q, num_heads * head_size] + int B, int S_q, int total_seq, + int num_heads, int num_kv_heads, int head_size, + int g_size) { + int bs = blockIdx.z; + int b = bs / S_q; + int s_q = bs % S_q; + int h = blockIdx.y; + if (b >= B || h >= num_heads || s_q >= S_q) return; + + int h_kv = h / g_size; + int v_base = (b * num_kv_heads + h_kv) * total_seq * head_size; + int s_offset = ((b * num_heads + h) * S_q + s_q) * total_seq; + int o_offset = ((b * S_q + s_q) * num_heads + h) * head_size; + + for (int d = threadIdx.x; d < head_size; d += blockDim.x) { + float acc = 0.0f; + for (int s_kv = 0; s_kv < total_seq; ++s_kv) { + acc += scores[s_offset + s_kv] * + static_cast(v_full[v_base + s_kv * head_size + d]); + } + // Clamp to fp16 range — large ||V|| layers can otherwise produce inf which + // turns the rest of the network into NaN. + if (!isfinite(acc)) acc = 0.0f; + if (acc > 65000.0f) acc = 65000.0f; + if (acc < -65000.0f) acc = -65000.0f; + output[o_offset + d] = static_cast(acc); + } +} + +// ============================================================================= +// v4 fused FlashAttention-style kernel. +// +// Replaces the v3.5 split (TQScores -> TQSoftmaxRow -> TQOutput) which +// materialised a [B, num_heads, S_q, total_seq] fp32 scores buffer. At +// S=4096 that buffer is ~2 GB *per attention call* and dominates HBM traffic +// (we measured prompt-step at p=4096 running 7.4x slower than fp16 +// FlashAttention purely because of this). The fused kernel uses online +// softmax: one block per (b, h, s_q) output, walks K_full/V_full in tiles +// of kBlockK rows, keeps a running (max, sum, accumulator) trio in +// registers, and never writes the score matrix back to HBM. +// +// Algorithm (FlashAttention-2): +// m = -inf, l = 0, acc = 0 +// for tile in [0..total_seq) step kBlockK: +// load K_tile, V_tile into smem +// for s in tile: +// if causal-masked (s > q_pos): score = -inf +// else: score = (Q . K_s) * scale +// m_new = max(m, max_s(score)) +// alpha = exp(m - m_new); l *= alpha; acc *= alpha +// for s: p_s = exp(score_s - m_new); l += p_s; acc += p_s * V_s +// m = m_new +// acc /= l +// write acc to output +// ============================================================================= +// v5 q-tiled FlashAttention kernel. +// +// Each block handles kBlockQ consecutive Q rows for one (b, h). The K/V +// rows of each tile are loaded ONCE and shared across all kBlockQ queries +// in the block, which reduces K/V HBM reads by kBlockQ-fold compared to +// v4-lite (where each Q row had its own block walking the whole cache). +// +// Threading invariant: blockDim.x == kHeadDim == kBlockK. In phase 1 +// (score compute) thread tid plays the role of "K row tid", computing +// score[q][tid] = Q[q] . K_tile[tid] * scale. In phase 2 (accumulator +// update) thread tid plays the role of "output dim tid", computing +// my_acc[q] += p[q][s] * V_tile[s][tid] for s in 0..kBlockK. +template +__global__ void TQFlashAttentionQTiledKernel( + const T* __restrict__ query, + const T* __restrict__ k_full, + const T* __restrict__ v_full, + T* __restrict__ output, + int B, int S_q, int total_seq, int past_seq, + int num_heads, int num_kv_heads, int head_size, + float scale, int g_size) { + static_assert(kBlockK == kHeadDim, "v5 kernel requires kBlockK == kHeadDim"); + const int b = blockIdx.x; + const int h = blockIdx.y; + const int s_q_block = blockIdx.z; // index into ceil(S_q / kBlockQ) + if (b >= B || h >= num_heads) return; + + const int s_q_base = s_q_block * kBlockQ; + if (s_q_base >= S_q) return; + const int q_count = min(kBlockQ, S_q - s_q_base); // last block may be partial + + const int h_kv = h / g_size; + const int tid = threadIdx.x; + + // Smem layout. + __shared__ T smem_q[kBlockQ][kHeadDim]; + __shared__ T smem_k[kBlockK][kHeadDim]; + __shared__ T smem_v[kBlockK][kHeadDim]; + __shared__ float smem_scores[kBlockQ][kBlockK]; + __shared__ float s_m[kBlockQ]; // running max per query + __shared__ float s_l[kBlockQ]; // running denominator per query + __shared__ float s_alpha[kBlockQ]; // exp(prev_m - new_m), set each tile + + if (tid < kBlockQ) { s_m[tid] = -INFINITY; s_l[tid] = 0.0f; } + + // Load Q rows for this block. Each thread loads one element per query. + for (int q = 0; q < q_count; ++q) { + const int q_off = ((b * S_q + s_q_base + q) * num_heads + h) * kHeadDim; + smem_q[q][tid] = query[q_off + tid]; + } + __syncthreads(); + + // Per-thread accumulator: one float per query, output dim == tid. + float my_acc[kBlockQ]; + #pragma unroll + for (int q = 0; q < kBlockQ; ++q) my_acc[q] = 0.0f; + + for (int tile = 0; tile < total_seq; tile += kBlockK) { + const int tile_n = min(kBlockK, total_seq - tile); + + // Load K_tile and V_tile. Thread tid loads K_tile[tid] and V_tile[tid]. + if (tid < tile_n) { + const int row_off = (b * num_kv_heads + h_kv) * total_seq * kHeadDim + + (tile + tid) * kHeadDim; + #pragma unroll + for (int i = 0; i < kHeadDim; ++i) { + smem_k[tid][i] = k_full[row_off + i]; + smem_v[tid][i] = v_full[row_off + i]; + } + } + __syncthreads(); + + // Phase 1: compute scores. Thread tid computes score[q][tid] for all q. + if (tid < tile_n) { + const int s_kv_global = tile + tid; + #pragma unroll + for (int q = 0; q < q_count; ++q) { + const int q_pos = past_seq + s_q_base + q; + if (s_kv_global > q_pos) { + smem_scores[q][tid] = -INFINITY; + } else { + float dot = 0.0f; + #pragma unroll + for (int i = 0; i < kHeadDim; ++i) { + dot += static_cast(smem_q[q][i]) * static_cast(smem_k[tid][i]); + } + smem_scores[q][tid] = dot * scale; + } + } + } + __syncthreads(); + + // Phase 2 prep: per-query running-max bookkeeping. Thread q (q < kBlockQ) + // does the serial reduction for query q. + if (tid < q_count) { + const int q = tid; + float tile_max = -INFINITY; + for (int s = 0; s < tile_n; ++s) tile_max = fmaxf(tile_max, smem_scores[q][s]); + const float prev_m = s_m[q]; + const float m_new = fmaxf(prev_m, tile_max); + s_alpha[q] = (prev_m == -INFINITY) ? 0.0f : expf(prev_m - m_new); + float l_tile = 0.0f; + for (int s = 0; s < tile_n; ++s) { + float p = (m_new == -INFINITY) ? 0.0f : expf(smem_scores[q][s] - m_new); + smem_scores[q][s] = p; + l_tile += p; + } + s_l[q] = s_l[q] * s_alpha[q] + l_tile; + s_m[q] = m_new; + } + __syncthreads(); + + // Phase 2: each thread updates its own slice of my_acc for every query. + if (tid < kHeadDim) { + #pragma unroll + for (int q = 0; q < q_count; ++q) { + float new_contrib = 0.0f; + for (int s = 0; s < tile_n; ++s) { + new_contrib += smem_scores[q][s] * static_cast(smem_v[s][tid]); + } + my_acc[q] = s_alpha[q] * my_acc[q] + new_contrib; + } + } + __syncthreads(); + } + + // Normalize and write out. One thread per output dim. + if (tid < kHeadDim) { + for (int q = 0; q < q_count; ++q) { + float out = my_acc[q] / fmaxf(s_l[q], 1e-30f); + if (!isfinite(out)) out = 0.0f; + if (out > 65000.0f) out = 65000.0f; + if (out < -65000.0f) out = -65000.0f; + const int o_offset = ((b * S_q + s_q_base + q) * num_heads + h) * kHeadDim; + output[o_offset + tid] = static_cast(out); + } + } +} + +// ============================================================================= +// v6 q-tiled FlashAttention kernel — wmma tensor cores for the Q*K^T phase. +// +// The fragment math operates only on half-precision A/B with float +// accumulators on Ampere/Ada (sm_80+). We dispatch this kernel only when +// T == half. For bf16 / future paths, we fall through to v5. +// +// Tile shape: m=16, n=16, k=16 (the canonical Ampere fp16 mma shape). +// For (kBlockQ=32, kBlockK=64, kHeadDim=64) we have: +// 2 q-tiles × 4 s-tiles × 4 k-iterations = 32 mma_sync calls per K/V tile. +// Block has 2 warps (blockDim=64) so each warp owns half the (q-tile, s-tile) +// pairs. Phase 2 (online softmax + V update) is unchanged from v5. +template +__global__ void TQFlashAttentionWmmaKernel( + const half* __restrict__ query, + const half* __restrict__ k_full, + const half* __restrict__ v_full, + half* __restrict__ output, + int B, int S_q, int total_seq, int past_seq, + int num_heads, int num_kv_heads, int head_size, + float scale, int g_size) { + using namespace nvcuda; + static_assert(kBlockK == kHeadDim, "v6 kernel requires kBlockK == kHeadDim"); + static_assert(kHeadDim % 16 == 0, "v6 needs kHeadDim divisible by 16"); + static_assert(kBlockQ % 16 == 0, "v6 needs kBlockQ divisible by 16"); + static_assert(kBlockK % 16 == 0, "v6 needs kBlockK divisible by 16"); + + constexpr int kMmaM = 16; + constexpr int kMmaN = 16; + constexpr int kMmaK = 16; + constexpr int kQTiles = kBlockQ / kMmaM; // e.g. 32 / 16 = 2 + constexpr int kSTiles = kBlockK / kMmaN; // e.g. 64 / 16 = 4 + constexpr int kKIters = kHeadDim / kMmaK; // e.g. 64 / 16 = 4 + + const int b = blockIdx.x; + const int h = blockIdx.y; + const int s_q_block = blockIdx.z; + if (b >= B || h >= num_heads) return; + + const int s_q_base = s_q_block * kBlockQ; + if (s_q_base >= S_q) return; + const int q_count = min(kBlockQ, S_q - s_q_base); + + const int h_kv = h / g_size; + const int tid = threadIdx.x; + const int warp_id = tid / 32; + const int n_warps = blockDim.x / 32; + + __shared__ half smem_q[kBlockQ][kHeadDim]; + __shared__ half smem_k[kBlockK][kHeadDim]; + __shared__ half smem_v[kBlockK][kHeadDim]; + __shared__ float smem_scores[kBlockQ][kBlockK]; + __shared__ float s_m[kBlockQ]; + __shared__ float s_l[kBlockQ]; + __shared__ float s_alpha[kBlockQ]; + + if (tid < kBlockQ) { s_m[tid] = -INFINITY; s_l[tid] = 0.0f; } + + // Vectorised Q load: each thread loads kBlockQ / blockDim Q elements. + // For blockDim=64 and kBlockQ * kHeadDim = 32 * 64 = 2048 elements, that's + // 32 elements per thread — covered by the simple per-thread loop below. + for (int q = 0; q < q_count; ++q) { + const int q_off = ((b * S_q + s_q_base + q) * num_heads + h) * kHeadDim; + smem_q[q][tid] = query[q_off + tid]; + } + __syncthreads(); + + float my_acc[kBlockQ]; + #pragma unroll + for (int q = 0; q < kBlockQ; ++q) my_acc[q] = 0.0f; + + // Causal early-exit cap: the highest q_pos in this Q block is + // past_seq + s_q_base + q_count - 1. K rows past that index are masked to + // -inf for every query in the block, so we don't need to load or compute + // them at all. For the prompt step (Q rows uniform over 0..S-1) this + // halves the work on average. + const int max_q_pos = past_seq + s_q_base + q_count - 1; + const int last_useful_kv = min(total_seq - 1, max_q_pos); + + for (int tile = 0; tile <= last_useful_kv; tile += kBlockK) { + const int tile_n = min(kBlockK, total_seq - tile); + + if (tid < tile_n) { + const int row_off = (b * num_kv_heads + h_kv) * total_seq * kHeadDim + + (tile + tid) * kHeadDim; + // Vectorised uint4 (8-fp16) loads — see v4-lite kernel comment. + static_assert(kHeadDim % 8 == 0, "vectorised load needs kHeadDim divisible by 8"); + const uint4* k_v4 = reinterpret_cast(k_full + row_off); + const uint4* v_v4 = reinterpret_cast(v_full + row_off); + uint4* smem_k_v4 = reinterpret_cast(&smem_k[tid][0]); + uint4* smem_v_v4 = reinterpret_cast(&smem_v[tid][0]); + #pragma unroll + for (int i = 0; i < kHeadDim / 8; ++i) { + smem_k_v4[i] = k_v4[i]; + smem_v_v4[i] = v_v4[i]; + } + } + __syncthreads(); + + // ---- Phase 1: scores via wmma tensor cores ------------------------ + // Distribute (q_tile, s_tile) pairs across warps. Each warp owns one + // pair at a time and walks the K dimension in kKIters steps. + constexpr int kTotalPairs = kQTiles * kSTiles; + for (int pair = warp_id; pair < kTotalPairs; pair += n_warps) { + const int q_tile = pair / kSTiles; + const int s_tile = pair % kSTiles; + const int q_row0 = q_tile * kMmaM; + const int s_col0 = s_tile * kMmaN; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + #pragma unroll + for (int kit = 0; kit < kKIters; ++kit) { + const int k_off = kit * kMmaK; + // A: smem_q rows [q_row0..q_row0+15], cols [k_off..k_off+15], row_major. + wmma::load_matrix_sync(a_frag, &smem_q[q_row0][k_off], kHeadDim); + // B: K^T treated as col_major over (k=head_dim, n=s) where the source + // storage is smem_k[s][k] (s outer). Loading from + // &smem_k[s_col0][k_off] with leading dim = kHeadDim gives the + // correct stride: column n is at offset n*kHeadDim from base. + wmma::load_matrix_sync(b_frag, &smem_k[s_col0][k_off], kHeadDim); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + // Store the 16x16 fp32 score tile to smem_scores[q_row0..][s_col0..] + // with row-major leading dim = kBlockK. + wmma::store_matrix_sync(&smem_scores[q_row0][s_col0], c_frag, kBlockK, wmma::mem_row_major); + } + __syncthreads(); + + // Apply scaling + causal mask. Same thread layout as v5: tid plays "K row". + if (tid < tile_n) { + const int s_kv_global = tile + tid; + #pragma unroll + for (int q = 0; q < q_count; ++q) { + const int q_pos = past_seq + s_q_base + q; + if (s_kv_global > q_pos) { + smem_scores[q][tid] = -INFINITY; + } else { + smem_scores[q][tid] *= scale; + } + } + } + __syncthreads(); + + // ---- Phase 2 prep: online softmax bookkeeping (per query) --------- + if (tid < q_count) { + const int q = tid; + float tile_max = -INFINITY; + for (int s = 0; s < tile_n; ++s) tile_max = fmaxf(tile_max, smem_scores[q][s]); + const float prev_m = s_m[q]; + const float m_new = fmaxf(prev_m, tile_max); + s_alpha[q] = (prev_m == -INFINITY) ? 0.0f : expf(prev_m - m_new); + float l_tile = 0.0f; + for (int s = 0; s < tile_n; ++s) { + float p = (m_new == -INFINITY) ? 0.0f : expf(smem_scores[q][s] - m_new); + smem_scores[q][s] = p; + l_tile += p; + } + s_l[q] = s_l[q] * s_alpha[q] + l_tile; + s_m[q] = m_new; + } + __syncthreads(); + + // ---- Phase 3: V accumulator update -------------------------------- + if (tid < kHeadDim) { + #pragma unroll + for (int q = 0; q < q_count; ++q) { + float new_contrib = 0.0f; + for (int s = 0; s < tile_n; ++s) { + new_contrib += smem_scores[q][s] * __half2float(smem_v[s][tid]); + } + my_acc[q] = s_alpha[q] * my_acc[q] + new_contrib; + } + } + __syncthreads(); + } + + if (tid < kHeadDim) { + for (int q = 0; q < q_count; ++q) { + float out = my_acc[q] / fmaxf(s_l[q], 1e-30f); + if (!isfinite(out)) out = 0.0f; + if (out > 65000.0f) out = 65000.0f; + if (out < -65000.0f) out = -65000.0f; + const int o_offset = ((b * S_q + s_q_base + q) * num_heads + h) * kHeadDim; + output[o_offset + tid] = __float2half(out); + } + } +} + +template +__global__ void TQFlashAttentionKernel( + const T* __restrict__ query, + const T* __restrict__ k_full, + const T* __restrict__ v_full, + T* __restrict__ output, + int B, int S_q, int total_seq, int past_seq, + int num_heads, int num_kv_heads, int head_size, + float scale, int g_size) { + const int bs = blockIdx.z; + const int b = bs / S_q; + const int s_q = bs % S_q; + const int h = blockIdx.y; + if (b >= B || h >= num_heads || s_q >= S_q) return; + + const int h_kv = h / g_size; + const int q_pos = past_seq + s_q; + + const int q_offset = ((b * S_q + s_q) * num_heads + h) * kHeadDim; + const int k_base = (b * num_kv_heads + h_kv) * total_seq * kHeadDim; + const int o_offset = ((b * S_q + s_q) * num_heads + h) * kHeadDim; + + // Block layout: blockDim.x == kHeadDim. Each thread owns one head dim. + // (Works when kBlockK <= kHeadDim, which we ensure by template params.) + static_assert(kBlockK <= kHeadDim, "kBlockK must be <= kHeadDim for this thread layout"); + const int tid = threadIdx.x; + + __shared__ T smem_q[kHeadDim]; + __shared__ T smem_k[kBlockK][kHeadDim]; + __shared__ T smem_v[kBlockK][kHeadDim]; + __shared__ float smem_scores[kBlockK]; // also used to hold p_s after softmax + __shared__ float s_m; // running max + __shared__ float s_l; // running denominator + __shared__ float s_alpha; // exp(prev_m - new_m), shared each tile + + if (tid == 0) { s_m = -INFINITY; s_l = 0.0f; } + + // Load Q once. One element per thread. + if (tid < kHeadDim) smem_q[tid] = query[q_offset + tid]; + + float my_acc = 0.0f; // this thread's slice of the output accumulator (1 dim) + + for (int tile = 0; tile < total_seq; tile += kBlockK) { + const int tile_end = min(tile + kBlockK, total_seq); + const int tile_n = tile_end - tile; + + // Vectorised K/V load: each thread loads ONE full row via uint4 (16-byte + // = 8-fp16 per transaction). Single coalesced HBM burst per row. + static_assert(kHeadDim % 8 == 0, "vectorised load needs kHeadDim divisible by 8"); + if (tid < tile_n) { + const int row_off = k_base + (tile + tid) * kHeadDim; + const uint4* k_v4 = reinterpret_cast(k_full + row_off); + const uint4* v_v4 = reinterpret_cast(v_full + row_off); + uint4* smem_k_v4 = reinterpret_cast(&smem_k[tid][0]); + uint4* smem_v_v4 = reinterpret_cast(&smem_v[tid][0]); + #pragma unroll + for (int i = 0; i < kHeadDim / 8; ++i) { + smem_k_v4[i] = k_v4[i]; + smem_v_v4[i] = v_v4[i]; + } + } + __syncthreads(); + + // (cp.async double-buffered loads were tried and regressed at our tile + // size: per-tile commit_group/wait_group overhead exceeds the overlap + // savings when the post-load compute is small. See git history.) + + // Each thread computes the score for one s in the tile. + if (tid < tile_n) { + const int s = tid; + const int s_kv_global = tile + s; + float dot = 0.0f; + #pragma unroll + for (int i = 0; i < kHeadDim; ++i) { + dot += static_cast(smem_q[i]) * static_cast(smem_k[s][i]); + } + dot *= scale; + if (s_kv_global > q_pos) dot = -INFINITY; // causal + smem_scores[s] = dot; + } + __syncthreads(); + + // Thread 0 does the running-max bookkeeping for this tile. + if (tid == 0) { + float tile_max = -INFINITY; + for (int s = 0; s < tile_n; ++s) tile_max = fmaxf(tile_max, smem_scores[s]); + const float prev_m = s_m; + const float m_new = fmaxf(prev_m, tile_max); + s_alpha = (prev_m == -INFINITY) ? 0.0f : expf(prev_m - m_new); + float l_tile = 0.0f; + for (int s = 0; s < tile_n; ++s) { + float p = (m_new == -INFINITY) ? 0.0f : expf(smem_scores[s] - m_new); + smem_scores[s] = p; + l_tile += p; + } + s_l = s_l * s_alpha + l_tile; + s_m = m_new; + } + __syncthreads(); + + // All threads update their own slice of the accumulator. + if (tid < kHeadDim) { + float new_contrib = 0.0f; + for (int s = 0; s < tile_n; ++s) { + new_contrib += smem_scores[s] * static_cast(smem_v[s][tid]); + } + my_acc = s_alpha * my_acc + new_contrib; + } + __syncthreads(); + } + + // Normalize and write out. + if (tid < kHeadDim) { + float out = my_acc / fmaxf(s_l, 1e-30f); + if (!isfinite(out)) out = 0.0f; + if (out > 65000.0f) out = 65000.0f; + if (out < -65000.0f) out = -65000.0f; + output[o_offset + tid] = static_cast(out); + } +} + +template +Status LaunchTQAttention(cudaStream_t stream, + int B, int S_q, int total_seq, int past_seq, + int num_heads, int num_kv_heads, int head_size, + float scale, + const T* query, const T* k_full, const T* v_full, + float* scores, T* output) { + const int g_size = num_heads / num_kv_heads; + (void)scores; // unused in the fused path + + // v6 wmma kernel for the prompt step on fp16 inputs. Falls through to v5 + // for bf16 / non-fp16 paths. + if constexpr (std::is_same::value) { + if (head_size == 64 && S_q > 1) { + constexpr int kHeadDim = 64; + constexpr int kBlockK = 64; + constexpr int kBlockQ = 32; // larger sizes (64) hurt occupancy at 32K via reg pressure + dim3 grid(B, num_heads, (S_q + kBlockQ - 1) / kBlockQ); + dim3 block(kHeadDim); + TQFlashAttentionWmmaKernel<<>>( + query, k_full, v_full, output, B, S_q, total_seq, past_seq, + num_heads, num_kv_heads, head_size, scale, g_size); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); + } + } + // v5 q-tiled kernel for the prompt step (S_q > 1): each block handles a + // group of consecutive Q rows, sharing K/V loads across them. For + // decode (S_q == 1) we drop to v4-lite (one block per s_q is already + // optimal there). + if (head_size == 64 && S_q > 1) { + constexpr int kHeadDim = 64; + constexpr int kBlockK = 64; + constexpr int kBlockQ = 32; // tuned: bigger values help up to 32, then occupancy drops + dim3 grid(B, num_heads, (S_q + kBlockQ - 1) / kBlockQ); + dim3 block(kHeadDim); + TQFlashAttentionQTiledKernel<<>>( + query, k_full, v_full, output, B, S_q, total_seq, past_seq, + num_heads, num_kv_heads, head_size, scale, g_size); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); + } + // (head_dim=128 q-tiling deferred — needs kBlockK split from kHeadDim to + // fit shared memory. Falls through to v4-lite below for hd=128 prompts.) + // v4 fused kernel: decode-step path (S_q == 1) — one block per (b, h), + // online softmax over K/V tiles. blockDim == kHeadDim. + if (head_size == 64) { + constexpr int kHeadDim = 64; + constexpr int kBlockK = 64; + dim3 grid(1, num_heads, B * S_q); + dim3 block(kHeadDim); + TQFlashAttentionKernel<<>>( + query, k_full, v_full, output, B, S_q, total_seq, past_seq, + num_heads, num_kv_heads, head_size, scale, g_size); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); + } + if (head_size == 128) { + constexpr int kHeadDim = 128; + constexpr int kBlockK = 64; + dim3 grid(1, num_heads, B * S_q); + dim3 block(kHeadDim); + TQFlashAttentionKernel<<>>( + query, k_full, v_full, output, B, S_q, total_seq, past_seq, + num_heads, num_kv_heads, head_size, scale, g_size); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); + } + + // Fallback to the v3.5 split kernels for unsupported head sizes (96, 256, …). + // Scores kernel. + { + dim3 grid((total_seq + 255) / 256, num_heads, B * S_q); + dim3 block(256); + TQScoresKernel<<>>( + query, k_full, scores, B, S_q, total_seq, past_seq, + num_heads, num_kv_heads, head_size, scale, g_size); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + } + // Softmax kernel. + { + int block_size = 256; + while (block_size > total_seq && block_size > 32) block_size /= 2; + dim3 grid(1, num_heads, B * S_q); + dim3 block(block_size); + TQSoftmaxRowKernel<<>>( + scores, B, num_heads, S_q, total_seq); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + } + // Output kernel. + { + dim3 grid(1, num_heads, B * S_q); + dim3 block(64); // == head_size, simple + TQOutputKernel<<>>( + scores, v_full, output, B, S_q, total_seq, + num_heads, num_kv_heads, head_size, g_size); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + } + return Status::OK(); +} + +// ============================================================================= +// LaunchTurboQuantAttention: full attention via bulk dequant + standard fp16. +// ============================================================================= + +template +Status LaunchTurboQuantAttention( + const cudaDeviceProp& device_prop, + Stream* stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + static_assert(std::is_same::value, + "TurboQuant only supports uint8 cache type (T_CACHE = uint8)"); + + cudaStream_t cuda_stream = static_cast(stream->GetHandle()); + + const int B = parameters.batch_size; + const int H_kv = parameters.kv_num_heads; + const int new_seq = parameters.sequence_length; + const int past_seq = parameters.seqlen_past_kv_cache; + const int total_seq = past_seq + new_seq; + const int max_seq = parameters.seqlen_present_kv_cache; + const int D = parameters.head_size; + const int kbits = parameters.key_quant_bits; + const int vbits = parameters.value_quant_bits; + + // Cache slot last dim = max(K_slot, V_slot). + const int k_slot = (D * kbits + 7) / 8 + 2; + const int v_slot = (D * vbits + 7) / 8 + 4; + const int slot_last_dim = (k_slot > v_slot) ? k_slot : v_slot; + + // Materialize codebook as fp32 device buffer (small, max 16 entries). + // v3: do this entirely on-device so we don't force a host sync on every + // attention call. The codebook itself is constant per session, but we + // re-derive the fp32 view per call to keep the orchestrator stateless. + // 16 floats = 64 bytes; this is dominated by the kernel launch latency. + const int n_centroids = 1 << kbits; + float* d_codebook = nullptr; + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_codebook, n_centroids * sizeof(float), cuda_stream)); + TQConvertCodebookKernel<<<1, 32, 0, cuda_stream>>>(data.k_codebook, d_codebook, n_centroids); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + // v2 PATH: encode incoming K/V → present cache (compressed bytes), decode + // entire present cache → fp16 K/V buffers, run attention math directly via + // our simple custom kernels (TQScoresKernel + TQSoftmaxRowKernel + TQOutputKernel), + // then write to data.output. This produces real attention results with real + // tokens — same output a fp16 GQA would give, modulo the lossy cache. + + // Apply RoPE to incoming Q and K when the GQA op was configured with + // do_rotary=1 (e.g. LFM2/LFM2.5/Qwen3 ONNX exports inline RoPE inside the + // GQA op). The standard QkvToContext path does this via + // LaunchUnpackRoPEAppend; the TQ path bypasses that, so without this step + // the cache encodes positionless K and the model produces near-random + // logits even though no NaN appears. + T* d_Q_rot = nullptr; + T* d_K_rot = nullptr; + int* d_past_seq_lens = nullptr; + // The shared LaunchRotaryEmbeddingKernel is instantiated for `half` and ORT's + // `BFloat16` wrapper, NOT for the CUDA `__nv_bfloat16` we use here. Skip the + // bf16 path — wandler today only ships fp16 LLMs through TurboQuant and + // adding the bf16 instantiation would require touching the shared rotary + // template definitions. The fp16 path correctly applies inline RoPE. + const bool apply_rotary = parameters.do_rotary && std::is_same::value; + if constexpr (std::is_same::value) { + if (apply_rotary) { + // Q has shape (B, S_q, num_heads * D), K has shape (B, S_q, kv_num_heads * D), both BSNH. + const size_t q_elems = static_cast(B) * new_seq * parameters.num_heads * D; + const size_t k_elems = static_cast(B) * new_seq * H_kv * D; + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_Q_rot, q_elems * sizeof(T), cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_K_rot, k_elems * sizeof(T), cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_past_seq_lens, B * sizeof(int), cuda_stream)); + // past_sequence_lengths[b] = past_seq for every batch element. + { + std::vector host_past(B, past_seq); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_past_seq_lens, host_past.data(), + B * sizeof(int), cudaMemcpyHostToDevice, cuda_stream)); + } + const int rotary_dim = (parameters.rotary_dim > 0) ? parameters.rotary_dim : D; + constexpr int kPositionIdsFormat = 2; // use past_sequence_lengths[b] + s + constexpr int kMaxSequenceLength = 1 << 20; // bounds-check upper limit; only used by formats 0/1 + ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel( + cuda_stream, d_Q_rot, data.query, /*position_ids=*/nullptr, d_past_seq_lens, + data.cos_cache, data.sin_cache, B, new_seq, parameters.num_heads, D, + rotary_dim, kMaxSequenceLength, kPositionIdsFormat, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*is_input_bnsh_format=*/false))); + ORT_RETURN_IF_ERROR((LaunchRotaryEmbeddingKernel( + cuda_stream, d_K_rot, data.key, /*position_ids=*/nullptr, d_past_seq_lens, + data.cos_cache, data.sin_cache, B, new_seq, H_kv, D, + rotary_dim, kMaxSequenceLength, kPositionIdsFormat, parameters.rotary_interleaved, + device_prop.maxThreadsPerBlock, /*is_input_bnsh_format=*/false))); + } + } + + // Allocate temporary fp16 K/V buffers for the full cache. + const size_t kv_elements = static_cast(B) * H_kv * total_seq * D; + T *d_K_fp16 = nullptr, *d_V_fp16 = nullptr; + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_K_fp16, kv_elements * sizeof(T), cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_V_fp16, kv_elements * sizeof(T), cuda_stream)); + + // Copy past_key/past_value content into the first past_seq slots of present_key/value. + // The encode kernel will only write the new [past_seq, total_seq) slots; the decode + // kernel reads ALL [0, total_seq) slots, so the past must already live in present. + // (When past_present_share_buffer is true the same buffer is reused and this copy + // is a no-op for content but still safe — both pointers refer to the same memory.) + if (past_seq > 0 && data.past_key != nullptr && data.past_value != nullptr && + reinterpret_cast(data.past_key) != + reinterpret_cast(data.present_key)) { + const size_t past_bytes_per_layer = + static_cast(B) * H_kv * past_seq * slot_last_dim; + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + data.present_key, data.past_key, past_bytes_per_layer, + cudaMemcpyDeviceToDevice, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + data.present_value, data.past_value, past_bytes_per_layer, + cudaMemcpyDeviceToDevice, cuda_stream)); + } + + // v7 path: + // 1. Encode new K/V into the packed cache for future decode steps to read. + // 2. Decode runs on PAST slots only — the new slots get filled by step 3. + // 3. Copy fresh fp16 K/V (post-RoPE for K) from BSNH input straight into + // the BNSH K_out / V_out scratch buffer at slots [past_seq, total_seq). + const T* k_in = apply_rotary ? d_K_rot : data.key; + Status status = DispatchEncodeDecode( + cuda_stream, B, H_kv, new_seq, /*decode_count=*/past_seq, + max_seq, past_seq, D, + kbits, vbits, slot_last_dim, parameters.norm_correction, + k_in, data.value, d_codebook, + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), + d_K_fp16, d_V_fp16); + + if (status.IsOK()) { + Status copy_status = LaunchCopyFreshKV( + cuda_stream, k_in, data.value, d_K_fp16, d_V_fp16, + B, new_seq, H_kv, total_seq, past_seq, D); + if (!copy_status.IsOK()) status = copy_status; + } + + // Option ε: for first-prompt steps (past_seq == 0, S_q > 1), delegate the + // attention math to ORT's standard FlashAttention. We have the EXACT + // pre-quantization fp16 K/V (in d_K_rot / data.value or data.key / + // data.value) — running FA on those is much faster than walking the + // packed cache with our custom kernel. The encode kernel above has + // already populated the packed cache for FUTURE decode steps to read. + // + // This mirrors vLLM's TurboQuant prefill path (turboquant_attn.py:542): + // first-chunk prefill never touches the quantized cache; FA runs on the + // raw new K/V. We only fall through to the custom TQ kernel for + // continuation prefill / decode steps where past tokens are in the cache. + if constexpr (std::is_same::value) { + if (status.IsOK() && past_seq == 0 && new_seq > 1 && + data.softmax_lse != nullptr && data.padded_seq_lens != nullptr) { + // FA "no internal append" mode: pass kcache/vcache fully populated + // (d_K_fp16 / d_V_fp16 already hold the post-RoPE K and raw V from + // LaunchCopyFreshKV) plus nullptr for new_k/new_v. seqlens_k passes + // the TOTAL length already in kcache (= padded_seq_lens for first + // prompt) — same convention the standard GQA FlashAttention path uses. + void* q_in_ptr = const_cast(reinterpret_cast( + apply_rotary ? d_Q_rot : data.query)); + Status fa_status = onnxruntime::flash::mha_fwd_kvcache( + device_prop, cuda_stream, + /*q=*/q_in_ptr, + /*kcache=*/d_K_fp16, + /*vcache=*/d_V_fp16, + /*new_k=*/nullptr, // already appended via LaunchCopyFreshKV + /*new_v=*/nullptr, + /*out=*/data.output, + /*softmax_lse=*/reinterpret_cast(data.softmax_lse), + /*seqlens_k=*/data.padded_seq_lens, + /*rotary_cos=*/nullptr, // pre-rotated tensors passed + /*rotary_sin=*/nullptr, + /*cache_batch_idx=*/nullptr, + /*leftpad_k=*/nullptr, + /*head_sink=*/nullptr, + /*block_table=*/nullptr, + B, parameters.num_heads, H_kv, D, + /*seqlen_q=*/new_seq, + /*seqlen_k=*/total_seq, + /*seqlen_k_new=*/0, // 0 because not appending + /*rotary_dim=*/0, + /*scale=*/(parameters.scale != 0.0f) ? parameters.scale + : (1.0f / sqrtf(static_cast(D))), + /*softcap=*/parameters.softcap, + /*is_causal=*/true, + /*is_bf16=*/false, + /*use_smooth_softmax=*/parameters.use_smooth_softmax, + /*past_bsnh=*/false, // d_K_fp16 / d_V_fp16 are BNSH + /*num_splits=*/parameters.num_splits, + /*lse_accum=*/reinterpret_cast(data.softmax_lse_accum), + /*out_accum=*/reinterpret_cast(data.out_accum), + /*local_window=*/parameters.local_window_size - 1, + /*is_rotary_interleaved=*/false, + /*is_packed_qkv=*/false); + cudaFreeAsync(d_codebook, cuda_stream); + cudaFreeAsync(d_K_fp16, cuda_stream); + cudaFreeAsync(d_V_fp16, cuda_stream); + if (d_Q_rot != nullptr) cudaFreeAsync(d_Q_rot, cuda_stream); + if (d_K_rot != nullptr) cudaFreeAsync(d_K_rot, cuda_stream); + if (d_past_seq_lens != nullptr) cudaFreeAsync(d_past_seq_lens, cuda_stream); + return fa_status; + } + } + + if (!status.IsOK()) { + cudaFreeAsync(d_codebook, cuda_stream); + cudaFreeAsync(d_K_fp16, cuda_stream); + cudaFreeAsync(d_V_fp16, cuda_stream); + if (d_Q_rot != nullptr) cudaFreeAsync(d_Q_rot, cuda_stream); + if (d_K_rot != nullptr) cudaFreeAsync(d_K_rot, cuda_stream); + if (d_past_seq_lens != nullptr) cudaFreeAsync(d_past_seq_lens, cuda_stream); + return status; + } + + // Step 3: allocate scores buffer [B, num_heads, S_q, total_seq] fp32 only + // for the v3.5 fallback path (head sizes other than 64/128). The v4 fused + // kernel does online softmax in registers and never touches HBM scores. + const bool use_fused_attention = (D == 64 || D == 128); + const size_t scores_elements = + use_fused_attention ? 0 : (static_cast(B) * parameters.num_heads * new_seq * total_seq); + float* d_scores = nullptr; + if (scores_elements > 0) { + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_scores, scores_elements * sizeof(float), cuda_stream)); + } + + const float scale = + (parameters.scale != 0.0f) ? parameters.scale : (1.0f / sqrtf(static_cast(D))); + + // Use the RoPE-rotated Q when do_rotary=1. + const T* q_in = apply_rotary ? d_Q_rot : data.query; + Status attn_status = LaunchTQAttention(cuda_stream, B, new_seq, total_seq, past_seq, + parameters.num_heads, H_kv, D, scale, + q_in, d_K_fp16, d_V_fp16, + d_scores, data.output); + status = attn_status; + + cudaFreeAsync(d_codebook, cuda_stream); + cudaFreeAsync(d_K_fp16, cuda_stream); + cudaFreeAsync(d_V_fp16, cuda_stream); + if (d_scores != nullptr) cudaFreeAsync(d_scores, cuda_stream); + if (d_Q_rot != nullptr) cudaFreeAsync(d_Q_rot, cuda_stream); + if (d_K_rot != nullptr) cudaFreeAsync(d_K_rot, cuda_stream); + if (d_past_seq_lens != nullptr) cudaFreeAsync(d_past_seq_lens, cuda_stream); + return status; +} + +#if 0 // V2 code preserved for reference: full fp16-delegation path that calls + // the existing GQA attention via QkvToContext. Currently disabled + // because the unfused path's PrepareQKV requires more buffer setup than + // we have here and was hitting illegal-memory-access in attention_transpose. + // The v1 shortcut above is sufficient to validate end-to-end wiring + // (encode + decode + structural session.run() success). + cudaStream_t cuda_stream = static_cast(stream->GetHandle()); + contrib::GroupQueryAttentionParameters fp_params = parameters; + fp_params.kv_quant_method = KVQuantMethod::NONE; + fp_params.k_quant_type = KVQuantizationType::NONE; + fp_params.v_quant_type = KVQuantizationType::NONE; + fp_params.kv_cache_bit_width = 0; + fp_params.is_first_prompt = true; // tells the kernel "no past, just compute attention from K/V" + fp_params.sequence_length = total_seq; + fp_params.total_sequence_length = total_seq; + fp_params.seqlen_past_kv_cache = 0; + fp_params.seqlen_present_kv_cache = total_seq; + fp_params.past_present_share_buffer = false; + + GroupQueryAttentionData fp_data{}; + fp_data.query = data.query; + // Use the decoded K/V as the "current" K/V for first-prompt attention. + // We need just the new_seq tail (so the kernel only computes attention for + // the new positions), but the keys are over the full total_seq. + // For minimal v1 we run "attention over full sequence" and only the last + // new_seq Q rows contribute to the output that wandler reads — same shape. + fp_data.key = d_K_fp16; + fp_data.value = d_V_fp16; + fp_data.cos_cache = data.cos_cache; + fp_data.sin_cache = data.sin_cache; + fp_data.head_sink = data.head_sink; + fp_data.position_ids = data.position_ids; + fp_data.output = data.output; + + // We need a present_key / present_value that the fp16 path can write to. + // Allocate fp16 stub buffers (shape: B * H_kv * total_seq * D). + T* d_present_k_fp16 = nullptr; + T* d_present_v_fp16 = nullptr; + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_present_k_fp16, kv_elements * sizeof(T), cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&d_present_v_fp16, kv_elements * sizeof(T), cuda_stream)); + fp_data.present_key = d_present_k_fp16; + fp_data.present_value = d_present_v_fp16; + // past_key / past_value left as nullptr for is_first_prompt path. + + // We also need cublas + the rest of the GroupQueryAttention surrounding the + // call. For v1 we delegate to QkvToContext directly with a fresh cublas + // handle from the stream's parent provider — but we don't have access here. + // Simpler v1: use a global cublas handle obtained via the stream. + cublasHandle_t cublas; + cublasCreate(&cublas); + cublasSetStream(cublas, cuda_stream); + + // Recompute past_seq_lens / total_seq_lens for the inner call. + // We use the existing buffers from data (already populated by the outer caller). + fp_data.past_seq_lens = data.past_seq_lens; + fp_data.total_seq_lens = data.total_seq_lens; + fp_data.padded_seq_lens = data.padded_seq_lens; + fp_data.softmax_lse = data.softmax_lse; + fp_data.softmax_lse_accum = data.softmax_lse_accum; + fp_data.out_accum = data.out_accum; + fp_data.qkv_buffer = data.qkv_buffer; + fp_data.fmha_buffer = data.fmha_buffer; + fp_data.k = data.k; + fp_data.v = data.v; + // Use the unfused path for simplicity: it has fewer constraints and works for + // any head_size. It's the "math" path that does Q@K^T -> softmax -> @V via + // cublas + element-wise softmax kernel. + fp_data.use_flash_attention = false; + fp_data.use_memory_efficient_attention = false; + fp_data.use_xqa = false; + fp_data.use_unfused = true; + + // Allocate unfused scratch buffers. + const size_t Bs = static_cast(B); + const size_t N_q = static_cast(parameters.num_heads); + const size_t S_q = static_cast(total_seq); + const size_t Hs = static_cast(D); + const size_t S_kv = static_cast(total_seq); + + auto align256 = [](size_t v) -> size_t { return ((v + 255) / 256) * 256; }; + const size_t q_bnsh_bytes = align256(Bs * N_q * S_q * Hs * sizeof(T)); + const size_t y_bnsh_bytes = align256(Bs * N_q * S_q * Hs * sizeof(T)); + const size_t ws_bytes = onnxruntime::contrib::cuda::GetUnfusedAttentionWorkspaceSize( + static_cast(Bs), static_cast(N_q), static_cast(S_q), static_cast(S_kv)); + + uint8_t* unfused_scratch = nullptr; + CUDA_RETURN_IF_ERROR(cudaMallocAsync(&unfused_scratch, q_bnsh_bytes + y_bnsh_bytes + ws_bytes, cuda_stream)); + fp_data.unfused_q_bnsh = reinterpret_cast(unfused_scratch); + fp_data.unfused_y_bnsh = reinterpret_cast(unfused_scratch + q_bnsh_bytes); + fp_data.unfused_workspace = reinterpret_cast(unfused_scratch + q_bnsh_bytes + y_bnsh_bytes); + + Status final_status = QkvToContext(device_prop, cublas, stream, fp_params, fp_data); + + cublasDestroy(cublas); + cudaFreeAsync(d_codebook, cuda_stream); + cudaFreeAsync(d_K_fp16, cuda_stream); + cudaFreeAsync(d_V_fp16, cuda_stream); + cudaFreeAsync(d_present_k_fp16, cuda_stream); + cudaFreeAsync(d_present_v_fp16, cuda_stream); + cudaFreeAsync(unfused_scratch, cuda_stream); + return final_status; +#endif // 0 — v2 fp16-delegation code + +template Status LaunchTurboQuantAttention( + const cudaDeviceProp&, Stream*, + contrib::GroupQueryAttentionParameters&, + GroupQueryAttentionData&); + +template Status LaunchTurboQuantAttention<__nv_bfloat16, uint8_t>( + const cudaDeviceProp&, Stream*, + contrib::GroupQueryAttentionParameters&, + GroupQueryAttentionData<__nv_bfloat16, uint8_t>&); + +// ============================================================================= +// Kept for backward compat: roundtrip helper used by validation. +// ============================================================================= + +template +Status LaunchTurboQuantEncodeDecodeRoundtrip( + const cudaDeviceProp& /*device_prop*/, + Stream* /*stream*/, + int /*batch_size*/, + int /*n_kv_heads*/, + int /*seq_len*/, + int /*head_size*/, + int /*key_bits*/, + int /*value_bits*/, + bool /*norm_correction*/, + const T* /*K*/, + const T* /*V*/, + const T* /*k_codebook*/, + const T* /*hadamard*/, + T* /*K_recon*/, + T* /*V_recon*/) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Roundtrip helper deprecated; use LaunchTurboQuantAttention or test via gtest harness."); +} + +template Status LaunchTurboQuantEncodeDecodeRoundtrip( + const cudaDeviceProp&, Stream*, int, int, int, int, int, int, bool, + const half*, const half*, const half*, const half*, half*, half*); + +template Status LaunchTurboQuantEncodeDecodeRoundtrip<__nv_bfloat16>( + const cudaDeviceProp&, Stream*, int, int, int, int, int, int, bool, + const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +// ============================================================================= +// C-style entry point retained for tests that may dlopen this lib. +// ============================================================================= +extern "C" __attribute__((visibility("default"))) int RunTurboQuantRoundtrip_fp16( + int /*batch_size*/, int /*n_kv_heads*/, int /*seq_len*/, int /*head_size*/, + int /*key_bits*/, int /*value_bits*/, int /*norm_correction*/, + const void* /*d_K*/, const void* /*d_V*/, const void* /*d_codebook*/, + void* /*d_K_recon*/, void* /*d_V_recon*/, + void* /*cuda_stream*/) { + return -1; // deprecated +} diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.h new file mode 100644 index 0000000000000..1c16b0b7238b8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant_impl.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "core/framework/allocator.h" +#include "contrib_ops/cuda/bert/attention_data.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Launches the TurboQuant CUDA path for GroupQueryAttention. +// +// Routing: +// - When parameters.is_first_prompt: encodes incoming K/V into the present cache +// (TurboQuant slot layout) AND runs fp16 attention output (using non-encoded inputs). +// - When sequence_length == 1 (decode): encodes the single new K/V token and appends +// to present cache, then computes attention scores in rotated space against the +// packed cache via TQDecodeScoreKernel, applies softmax, and computes the V +// weighted sum via TQDecodeWeightedSumKernel. +// +// Only `(MLFloat16, uint8_t)` and `(BFloat16, uint8_t)` (T, U) instantiations are valid. +// +// Returns Status::OK on success, or an error if (head_size, key_bits, value_bits) +// is unsupported (currently head_size in {64, 128, 256}, key_bits in {3, 4}, +// value_bits in {3, 4}). +template +Status LaunchTurboQuantAttention( + const cudaDeviceProp& device_prop, + Stream* stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +// Standalone wrapper for unit tests: takes raw K and V, encodes them to TurboQuant +// format, decodes them back to fp16, and writes the result to k_recon / v_recon. +// +// This exercises the encode/decode kernels without the full attention pipeline. +// Used by gtests in test/contrib_ops/turboquant_kv_test.cc. +template +Status LaunchTurboQuantEncodeDecodeRoundtrip( + const cudaDeviceProp& device_prop, + Stream* stream, + int batch_size, + int n_kv_heads, + int seq_len, + int head_size, + int key_bits, // 3 or 4 + int value_bits, // 3 or 4 + bool norm_correction, + const T* K, // (B, H_kv, S, D) input fp16/bf16 keys + const T* V, // (B, H_kv, S, D) input fp16/bf16 values + const T* k_codebook, // (2^key_bits) static centroids + const T* hadamard, // (D, D) Walsh-Hadamard matrix (unused — we apply FWHT in kernel) + T* k_recon, // (B, H_kv, S, D) output: K reconstructed in rotated space + T* v_recon // (B, H_kv, S, D) output: V reconstructed (uniform dequant) +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime