Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions include/onnxruntime/core/framework/int3.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cassert>
#include <type_traits>
#include "core/common/common.h"
#include <gsl/gsl>

namespace onnxruntime {

// Stores 8 packed 3-bit unsigned elements in 3 bytes (24 bits exactly).
//
// Bit layout (within the 24-bit word, little-endian):
// word = byte[0] | (byte[1] << 8) | (byte[2] << 16)
// value_i = (word >> (i * 3)) & 0x7 for i in [0, 8)
//
// This layout matches vLLM's TurboQuant store kernel
// (vllm/v1/attention/ops/triton_turboquant_store.py) so binaries produced by
// either runtime are interchangeable.
//
// Used for:
// - K-cache 3-bit Lloyd-Max codebook indices in TurboQuant.
// - V-cache 3-bit uniform-quant indices when value_quant_bits == 3.
//
// Storage tip: callers should always operate on whole 8-element groups
// (i.e. tensor extents along the packed axis must be multiples of 8). The
// 3-bit encoding has no clean partial-byte representation otherwise.
struct UInt3x8 {
static constexpr uint8_t min_val = 0;
static constexpr uint8_t max_val = 7;
static constexpr size_t kElementsPerPack = 8;
static constexpr size_t kBytesPerPack = 3;

std::byte bytes_[kBytesPerPack]{};

UInt3x8() = default;

// Construct from 8 unpacked uint8 elements. All must be in [0, 7].
explicit UInt3x8(const uint8_t (&values)[kElementsPerPack]) {
uint32_t word = 0;
for (size_t i = 0; i < kElementsPerPack; ++i) {
assert(values[i] <= max_val);
word |= (static_cast<uint32_t>(values[i]) & 0x7u) << (i * 3);
}
bytes_[0] = static_cast<std::byte>(word & 0xFFu);
bytes_[1] = static_cast<std::byte>((word >> 8) & 0xFFu);
bytes_[2] = static_cast<std::byte>((word >> 16) & 0xFFu);
}

inline uint8_t GetElem(size_t index) const {
assert(index < kElementsPerPack);
const uint32_t word = static_cast<uint32_t>(bytes_[0]) |
(static_cast<uint32_t>(bytes_[1]) << 8) |
(static_cast<uint32_t>(bytes_[2]) << 16);
return static_cast<uint8_t>((word >> (index * 3)) & 0x7u);
}

inline void SetElem(size_t index, uint8_t val) {
assert(index < kElementsPerPack);
assert(val <= max_val);
uint32_t word = static_cast<uint32_t>(bytes_[0]) |
(static_cast<uint32_t>(bytes_[1]) << 8) |
(static_cast<uint32_t>(bytes_[2]) << 16);
const uint32_t mask = ~(0x7u << (index * 3));
word = (word & mask) | ((static_cast<uint32_t>(val) & 0x7u) << (index * 3));
bytes_[0] = static_cast<std::byte>(word & 0xFFu);
bytes_[1] = static_cast<std::byte>((word >> 8) & 0xFFu);
bytes_[2] = static_cast<std::byte>((word >> 16) & 0xFFu);
}

// Number of UInt3x8 packs required to store n unpacked 3-bit elements.
// Caller must ensure n is a multiple of 8.
static constexpr size_t CalcNumPacks(size_t num_3bit_elems) {
return num_3bit_elems / kElementsPerPack;
}

// Bytes required to store n unpacked 3-bit elements.
static constexpr size_t CalcNumBytes(size_t num_3bit_elems) {
return CalcNumPacks(num_3bit_elems) * kBytesPerPack;
}

// Bulk unpack: turn N packed groups (3 bytes each) into N*8 bytes [0, 7].
static bool Unpack(gsl::span<uint8_t> dst, gsl::span<const UInt3x8> src) {
if (dst.size() != src.size() * kElementsPerPack) {
return false;
}
for (size_t i = 0; i < src.size(); ++i) {
for (size_t j = 0; j < kElementsPerPack; ++j) {
dst[i * kElementsPerPack + j] = src[i].GetElem(j);
}
}
return true;
}

// Bulk pack: turn N*8 bytes [0, 7] into N packed groups.
static bool Pack(gsl::span<UInt3x8> dst, gsl::span<const uint8_t> src) {
if (src.size() != dst.size() * kElementsPerPack) {
return false;
}
for (size_t i = 0; i < dst.size(); ++i) {
uint8_t buf[kElementsPerPack];
for (size_t j = 0; j < kElementsPerPack; ++j) {
buf[j] = src[i * kElementsPerPack + j];
assert(buf[j] <= max_val);
}
dst[i] = UInt3x8(buf);
}
return true;
}
};

static_assert(sizeof(UInt3x8) == 3, "UInt3x8 must be exactly 3 bytes");

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -537,3 +537,23 @@ static const char* const kOrtSessionOptionsRecordEpGraphAssignmentInfo = "sessio
// - "0": disable. (default)
// - "1": enable.
static const char* const kOrtSessionOptionEpEnableWeightlessEpContextNodes = "ep.enable_weightless_ep_context_nodes";

// TurboQuant KV-cache rewrite preset. When set, the TurboQuantKVFusion graph
// transformer rewrites every GroupQueryAttention node in the model to use
// TurboQuant KV-cache compression at session-create time. This means the user
// can load a stock fp16 q4f16 .onnx model from HuggingFace and opt in to
// TurboQuant via runtime config alone — no offline model conversion needed.
//
// Recognized values:
// - "" / "none" / "off": disabled (default).
// - "turboquant_4bit_nc": 4-bit K, 4-bit V, norm correction on.
// - "turboquant_k3v4_nc": 3-bit K, 4-bit V, norm correction on.
// - "turboquant_3bit_nc": 3-bit K, 3-bit V, norm correction on.
//
// Requires CUDA EP (the TurboQuant kernels live there). No-op on CPU.
static const char* const kOrtSessionOptionsTurboQuantKVMethod = "optimization.turboquant_kv_method";

// TurboQuant boundary-protection layers. Number of attention layers at the
// start AND the end of the network to leave at fp16 (matching vLLM's behaviour).
// Default "2". Set to "0" to TurboQuant every GQA layer.
static const char* const kOrtSessionOptionsTurboQuantKVBoundary = "optimization.turboquant_kv_boundary";
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ enum class KVQuantizationType : int {
PER_CHANNEL = 2,
};

// Enum to select KV-cache compression method.
//
// CLASSIC modes use a single global scale per K/V tensor (PER_TENSOR) or one
// scale per channel (PER_CHANNEL) and store linearly-quantized integers. The
// scale tensors are passed via inputs `k_scale` and `v_scale`.
//
// TURBOQUANT applies a Walsh-Hadamard rotation to keys before scalar
// quantization with a static Lloyd-Max codebook (3- or 4-bit), and uses
// uniform asymmetric quantization for values. Codebook + Hadamard are graph
// initializers; per-token vec_norm + per-token v_scale/v_zero live alongside
// the packed bytes in the cache. See:
// onnxruntime/contrib_ops/cuda/bert/group_query_attention_turboquant.cuh
// onnxruntime/contrib_ops/webgpu/bert/flash_attention_*_turboquant.wgsl.template
//
// TURBOQUANT preserves attention semantics: scoring runs in the rotated space
// because Hadamard is orthogonal, so Q.K = (Q@H).(k_hat@H) * ||k||.
enum class KVQuantMethod : int {
NONE = 0,
CLASSIC = 1, // existing int4/int8/fp8 path via KVQuantizationType + k_scale/v_scale
TURBOQUANT = 2, // Hadamard + Lloyd-Max keys, uniform asymmetric values
};

constexpr bool LAYOUT_BSNH = false;
constexpr bool LAYOUT_BNSH = true;

Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ struct GroupQueryAttentionParameters : AttentionParameters {
KVQuantizationType k_quant_type = KVQuantizationType::NONE;
KVQuantizationType v_quant_type = KVQuantizationType::NONE;
int kv_cache_bit_width = 0;

// TurboQuant KV cache parameters (mutually exclusive with k_quant_type/v_quant_type
// when kv_quant_method == TURBOQUANT). The codebook + Hadamard pointers live in
// the GroupQueryAttentionData struct (CUDA-side) since they're device pointers.
KVQuantMethod kv_quant_method = KVQuantMethod::NONE;
int key_quant_bits = 0; // 3 or 4 when kv_quant_method == TURBOQUANT
int value_quant_bits = 0; // 3 or 4 when kv_quant_method == TURBOQUANT
bool norm_correction = false;
};

// Parameters deduced from node attributes and inputs/outputs.
Expand Down
24 changes: 14 additions & 10 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,20 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_

// For 4-bit quantized KV cache, actual dimension is head_size / 2 because 2 nibbles are packed into one byte.
// Note that we have checked that head_size is a multiple of 8 in Check_QKV.
int packed_head_size = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size;
if (past_key_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
past_key_dims[3], " expected ", packed_head_size);
}
if (past_value_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 3 should be same as head_size, got ",
past_value_dims[3], " expected ", packed_head_size);
// Sentinel: kv_cache_bit_width == -1 means "TurboQuant slot layout, skip last-dim check"
// (TQ packs differently and the last dim is bytes-per-slot, not head_size or head_size/2).
if (kv_cache_bit_width != -1) {
int packed_head_size = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size;
if (past_key_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
past_key_dims[3], " expected ", packed_head_size);
}
if (past_value_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 3 should be same as head_size, got ",
past_value_dims[3], " expected ", packed_head_size);
}
}
return Status::OK();
}
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <gsl/gsl>
#include <iostream>
#include "core/framework/allocator.h" // for AllocatorPtr
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cpu/bert/attention_parameters.h"

Expand Down Expand Up @@ -213,6 +214,13 @@ struct GroupQueryAttentionData {
T* unfused_q_bnsh = nullptr;
T* unfused_y_bnsh = nullptr;
void* unfused_workspace = nullptr;

// TurboQuant graph initializers (only set when parameters.kv_quant_method == TURBOQUANT).
// k_codebook : [2^key_quant_bits] static Lloyd-Max centroids (fp16/bf16 device pointer)
// hadamard : [head_size, head_size] Walsh-Hadamard rotation matrix (fp16/bf16 device pointer)
// Both are read-only graph initializers shipped inside the .onnx model.
const T* k_codebook = nullptr;
const T* hadamard = nullptr;
};

template <typename T>
Expand Down
Loading