Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 20 additions & 36 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -931,30 +931,22 @@ Status MatMulNBits<float>::ComputeBUnpacked(const Tensor* a,
"Only 2b and 4b quantization is supported for unpacked compute using "
"non-MLAS de-quantization for now");

// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
// Note: The kernel registration constrains T3 to {uint8_t, T1}, so for
// MatMulNBits<float> only float (not MLFloat16) ZP can reach this branch.
if (zero_points && zero_points->IsDataType<float>()) {
if (nbits_ == 2) {
ORT_ENFORCE(reorder_idx_data == nullptr,
"g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
// Simple 2-bit dequantization with float zero points
const float* float_zp = static_cast<const float*>(zero_points_data);
size_t k_blocks = (K_ + block_size_ - 1) / block_size_;
size_t packed_k = k_blocks * block_size_;
size_t bytes_per_col = packed_k / 4;
for (size_t n = 0; n < N_; n++) {
for (size_t k = 0; k < K_; k++) {
size_t block_idx = k / block_size_;
float scale = scales_data[n * k_blocks + block_idx];
float zp = float_zp[n * k_blocks + block_idx];
size_t packed_idx = n * bytes_per_col + k / 4;
int bit_offset = static_cast<int>((k % 4) * 2);
uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
tmp_b_data_ptr.get()[n * K_ + k] =
(static_cast<float>(q) - zp) * scale;
}
}
DequantizeBlockwise2Bits<float, float>(
Comment thread
tianleiwu marked this conversation as resolved.
tmp_b_data_ptr.get(),
b_data,
scales_data,
static_cast<const float*>(zero_points_data),
static_cast<int32_t>(block_size_),
column_wise_quant_,
static_cast<int32_t>(K_),
static_cast<int32_t>(N_),
thread_pool);
} else {
DequantizeBlockwise<float, float>(
tmp_b_data_ptr.get(), // dequantized output
Expand Down Expand Up @@ -1092,30 +1084,22 @@ Status MatMulNBits<MLFloat16>::ComputeBUnpacked(const Tensor* a,
"Only 2b and 4b quantization is supported for unpacked compute using "
"non-MLAS de-quantization for now");

// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
// Note: The kernel registration constrains T3 to {uint8_t, T1}, so for
// MatMulNBits<MLFloat16> only MLFloat16 (not float) ZP can reach this branch.
if (zero_points && zero_points->IsDataType<MLFloat16>()) {
if (nbits_ == 2) {
ORT_ENFORCE(reorder_idx_data == nullptr,
"g_idx (reorder index) is not supported for 2-bit quantization with float zero points");
// Simple 2-bit dequantization with MLFloat16 zero points
const MLFloat16* fp16_zp = static_cast<const MLFloat16*>(zero_points_data);
size_t k_blocks = (K_ + block_size_ - 1) / block_size_;
size_t packed_k = k_blocks * block_size_;
size_t bytes_per_col = packed_k / 4;
for (size_t n = 0; n < N_; n++) {
for (size_t k = 0; k < K_; k++) {
size_t block_idx = k / block_size_;
float scale = scales_ptr[n * k_blocks + block_idx];
float zp = fp16_zp[n * k_blocks + block_idx].ToFloat();
size_t packed_idx = n * bytes_per_col + k / 4;
int bit_offset = static_cast<int>((k % 4) * 2);
uint8_t q = (b_data[packed_idx] >> bit_offset) & 0x3;
tmp_b_data_ptr.get()[n * K_ + k] =
(static_cast<float>(q) - zp) * scale;
}
}
DequantizeBlockwise2Bits<float, MLFloat16>(
tmp_b_data_ptr.get(),
b_data,
scales_ptr,
static_cast<const MLFloat16*>(zero_points_data),
static_cast<int32_t>(block_size_),
column_wise_quant_,
static_cast<int32_t>(K_),
static_cast<int32_t>(N_),
thread_pool);
} else {
DequantizeBlockwise<float, MLFloat16>(
tmp_b_data_ptr.get(), // dequantized output
Expand Down
106 changes: 106 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,111 @@
const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

// 2-bit dequantization kernel for float/MLFloat16 zero points.
// Processes 16 elements at a time (16 x 2-bit = 32 bits = one uint32_t).
// Layout: columnwise packing — elements within a column are packed consecutively,
// output[n * K + k] = (quant_value - zp) * scale
template <class T, class zeroT>
void Dequantize2BitsKernel(
T* output, const uint8_t* quant_data, const T* scale_data,
const zeroT* zero_points, int block_size,
int groups_per_threadblock, int total_groups, int N, int K,
int blockIdx_x, int threadIdx_x) {
// Each "thread" handles 16 elements (one uint32 of packed 2-bit values)
constexpr int elements_per_thread = 16;
const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * elements_per_thread) / block_size);
if (group_id >= total_groups) {
return;
}
const int k_blocks = (K + block_size - 1) / block_size;

int n_idx = group_id / k_blocks;
int kb_idx = group_id % k_blocks;
int element_offset = group_id * block_size + ((threadIdx_x * elements_per_thread) & (block_size - 1));

const int k_offset = element_offset % (k_blocks * block_size);
const int n_offset = element_offset / (k_blocks * block_size);
if (n_offset >= N || k_offset >= K) {
return;
}

T* output_i = output + n_offset * K + k_offset;
// 16 elements × 2 bits = 4 bytes
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 4));
if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) {
const uint8_t* c = (const uint8_t*)(&quant_value);
quant_value = (uint32_t)c[0] |

Check warning on line 153 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc:153: Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4]
(uint32_t)c[1] << 8 |

Check warning on line 154 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc:154: Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4]
(uint32_t)c[2] << 16 |

Check warning on line 155 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc:155: Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4]
(uint32_t)c[3] << 24;

Check warning on line 156 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc:156: Using C-style cast. Use static_cast<uint32_t>(...) instead [readability/casting] [4]
}
const int remain_k = std::min(elements_per_thread, K - k_offset);

T scale = *(scale_data + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx));
float zp_f = 0.0f;
if (zero_points) {
if constexpr (std::is_same_v<zeroT, MLFloat16>) {
zp_f = (*(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx))).ToFloat();
} else {
zp_f = static_cast<float>(*(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(k_blocks) + static_cast<uint64_t>(kb_idx)));
}
}

if constexpr (std::is_same_v<T, MLFloat16>) {
T zp_adjust = -scale * MLFloat16(zp_f);
for (int i = 0; i < remain_k; i++) {
output_i[i] = static_cast<float>((quant_value >> (2 * i)) & 0x3) * scale + zp_adjust;
}
} else {
T zp_adjust = -scale * zp_f;
for (int i = 0; i < remain_k; i++) {
output_i[i] = T((quant_value >> (2 * i)) & 0x3) * scale + zp_adjust;
}
}
}

// Specialization of DequantizeBlockwise for qbits=2
template <typename inputT, typename zeroT>
void DequantizeBlockwise2Bits(
inputT* output,
const uint8_t* quant_data,
const inputT* scales_data,
const zeroT* zero_points,
int32_t block_size,
bool,
int32_t K,
int32_t N,
onnxruntime::concurrency::ThreadPool* pool) {
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
constexpr int elements_per_thread = 16;
ORT_ENFORCE(block_size > 0 && block_size <= 256 * elements_per_thread && block_size % elements_per_thread == 0,
"block_size must be positive, at most ", 256 * elements_per_thread,
", and a multiple of ", elements_per_thread, ", got: ", block_size);
int groups_per_threadblock = 256 * elements_per_thread / block_size;
int groups_per_K = ceildiv(K, block_size);
int total_groups = N * groups_per_K;
int blocks_per_grid = static_cast<int>(ceildiv(total_groups, groups_per_threadblock));
concurrency::ThreadPool::TrySimpleParallelFor(
pool, static_cast<std::ptrdiff_t>(blocks_per_grid),
[&](std::ptrdiff_t block_id) {
for (int j = 0; j < 256; j++) {
Comment thread
tianleiwu marked this conversation as resolved.
Dequantize2BitsKernel(output, quant_data, scales_data, zero_points,
block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
}
});
}

// Explicit instantiations for 2-bit dequantization
template void DequantizeBlockwise2Bits<float, float>(
float* output, const uint8_t* quant_data, const float* scales_data,
const float* zero_points, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise2Bits<float, MLFloat16>(
float* output, const uint8_t* quant_data, const float* scales_data,
const MLFloat16* zero_points, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

} // namespace contrib
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,19 @@ void DequantizeBlockwise(
int32_t N, // number of columns in quantized input
onnxruntime::concurrency::ThreadPool* thread_pool);

// Threaded 2-bit blockwise dequantization with float/MLFloat16 zero points.
// Does not support reorder_idx (g_idx).
template <typename inputT, typename zeroT>
void DequantizeBlockwise2Bits(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input
const inputT* scales_data, // quantization scales
const zeroT* zero_points, // quantization zero points
int32_t block_size, // quantization block size
bool, // columnwise quantization or row-wise
int32_t K, // number of rows in quantized input
int32_t N, // number of columns in quantized input
onnxruntime::concurrency::ThreadPool* thread_pool);

} // namespace contrib
} // namespace onnxruntime
Loading
Loading