From f39f2b885bb5bf92eb587a3c86d7c3cfc141b176 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 21 May 2026 01:20:41 +0000 Subject: [PATCH 1/5] perf: use embedded rounding in QuantizeRowToU8 Replace _mm512_roundscale_ps + _mm512_cvtps_epi32 with a single _mm512_cvt_roundps_epi32 that combines round-to-nearest-even and float-to-int32 conversion in one instruction, saving a vrndscaleps per loop iteration. The clamp now runs before the rounding step instead of after it (same results, since the clamp bounds 0.0 and 255.0 are already integers). --- onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index ac23a0703ddff..bb9d8e558a498 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -83,12 +83,13 @@ QuantizeRowToU8(const float* src, uint8_t* dst, size_t len) i = 0; for (; i < vec_end; i += 16) { __m512 v = _mm512_loadu_ps(src + i); - // q = round(v * inv_scale) + 128, clamped to [0, 255] + // q = (v * inv_scale) + 128, clamped to [0, 255] __m512 scaled = _mm512_fmadd_ps(v, inv_scale_vec, zp_vec); - scaled = _mm512_roundscale_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); scaled = _mm512_max_ps(scaled, min_val); scaled = _mm512_min_ps(scaled, max_clamp); - __m512i qi = _mm512_cvtps_epi32(scaled); + // Round-to-nearest-even and convert to int32 in a single instruction + // (AVX-512 embedded rounding eliminates a separate vrndscaleps). 
+ __m512i qi = _mm512_cvt_roundps_epi32(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); // Pack 16 int32 -> 16 uint8 __m128i packed = _mm512_cvtepi32_epi8(qi); _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), packed); From 76058f84a8d2e635b81658b279355ad28fd3d242 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 21 May 2026 01:25:12 +0000 Subject: [PATCH 2/5] fix: use int32 zero-point correction in VNNI dot products MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Perform the 128*sum(b) zero-point correction in int32 before converting to float. This avoids precision loss when |dot_i32| or |128*b_sum_i32| exceed 2^24 (where float32 loses integer precision), preventing potential catastrophic cancellation in the float subtraction. Applied to both VnniDotInt8PerTensor and VnniMultiDot4Int8PerTensor. Overflow is not a concern: for typical K<=16384, the max value is 128 * K * 127 ≈ 266M << INT32_MAX. --- .../mlas/lib/qkv_quant_kernel_avx512vnni.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index bb9d8e558a498..49e1b73a007b8 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -170,9 +170,11 @@ VnniDotInt8PerTensor( // Correction: dpbusd computed sum(a_u8 * b_s8). // We want sum((a_u8 - 128) * b_s8) = sum(a_u8 * b_s8) - 128 * sum(b_s8) - float corrected = static_cast(dot_i32) - 128.0f * static_cast(b_sum_i32); + // Perform correction in int32 to preserve precision (avoids float rounding + // when |dot_i32| or |128*b_sum_i32| exceed 2^24). 
+ int32_t corrected = dot_i32 - (128 * b_sum_i32); - return corrected * scale_a * scale_b; + return static_cast(corrected) * scale_a * scale_b; } // @@ -403,11 +405,11 @@ VnniMultiDot4Int8PerTensor( bs[3] += static_cast(b3[k]); } - const float zp = 128.0f; - out[0] = (static_cast(dot[0]) - zp * static_cast(bs[0])) * combined_scale; - out[1] = (static_cast(dot[1]) - zp * static_cast(bs[1])) * combined_scale; - out[2] = (static_cast(dot[2]) - zp * static_cast(bs[2])) * combined_scale; - out[3] = (static_cast(dot[3]) - zp * static_cast(bs[3])) * combined_scale; + // Zero-point correction in int32 for precision (see VnniDotInt8PerTensor). + out[0] = static_cast(dot[0] - 128 * bs[0]) * combined_scale; + out[1] = static_cast(dot[1] - 128 * bs[1]) * combined_scale; + out[2] = static_cast(dot[2] - 128 * bs[2]) * combined_scale; + out[3] = static_cast(dot[3] - 128 * bs[3]) * combined_scale; } // ============================================================================ From 14cdf174adc11edc1cf24c6def557f28028aa671 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 21 May 2026 01:30:21 +0000 Subject: [PATCH 3/5] perf: defer per-tensor scale to after accumulation in FusedDotInt8 For the per-tensor (single-scale) INT8 path, factor out the constant scale multiplication from the inner loop using the distributive property: sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]) This saves one vmulps per 8 elements (AVX2) or 16 elements (AVX-512) in the hot loop. The final result is multiplied by scale once after accumulation. Numerically equivalent within FP rounding. 
--- .../core/mlas/lib/qkv_quant_kernel_avx2.cpp | 19 ++++++++++++------- .../mlas/lib/qkv_quant_kernel_avx512vnni.cpp | 19 ++++++++++++------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp index d3681ff6bfdff..95985ac37f371 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp @@ -126,19 +126,19 @@ FusedDotInt8( acc0 = _mm256_fmadd_ps(a0, bf0, acc0); } } else { - __m256 scale_vec = _mm256_broadcast_ss(scales); + // Per-tensor: defer scale multiplication until after accumulation. + // sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]) + // This saves one vmulps per 8 elements in the hot loop. for (; k < vec_end; k += 16) { __m128i raw0 = _mm_loadl_epi64(reinterpret_cast(b_row + k)); __m256i i32_0 = _mm256_cvtepi8_epi32(raw0); __m256 bf0 = _mm256_cvtepi32_ps(i32_0); - bf0 = _mm256_mul_ps(bf0, scale_vec); __m256 a0 = _mm256_loadu_ps(a_row + k); acc0 = _mm256_fmadd_ps(a0, bf0, acc0); __m128i raw1 = _mm_loadl_epi64(reinterpret_cast(b_row + k + 8)); __m256i i32_1 = _mm256_cvtepi8_epi32(raw1); __m256 bf1 = _mm256_cvtepi32_ps(i32_1); - bf1 = _mm256_mul_ps(bf1, scale_vec); __m256 a1 = _mm256_loadu_ps(a_row + k + 8); acc1 = _mm256_fmadd_ps(a1, bf1, acc1); } @@ -146,7 +146,6 @@ FusedDotInt8( __m128i raw0 = _mm_loadl_epi64(reinterpret_cast(b_row + k)); __m256i i32_0 = _mm256_cvtepi8_epi32(raw0); __m256 bf0 = _mm256_cvtepi32_ps(i32_0); - bf0 = _mm256_mul_ps(bf0, scale_vec); __m256 a0 = _mm256_loadu_ps(a_row + k); acc0 = _mm256_fmadd_ps(a0, bf0, acc0); } @@ -161,9 +160,15 @@ FusedDotInt8( float dot = _mm_cvtss_f32(sum4); // Scalar tail - for (; k < K; ++k) { - float sc = per_channel ? 
scales[k] : scales[0]; - dot += a_row[k] * static_cast(b_row[k]) * sc; + if (per_channel) { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]) * scales[k]; + } + } else { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]); + } + dot *= scales[0]; } return dot; } diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index 49e1b73a007b8..5f614bd22e8f4 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -224,19 +224,19 @@ FusedDotInt8_Avx512( acc0 = _mm512_fmadd_ps(a0, bf0, acc0); } } else { - __m512 scale_vec = _mm512_set1_ps(scales[0]); + // Per-tensor: defer scale multiplication until after accumulation. + // sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]) + // This saves one vmulps per 16 elements in the hot loop. for (; k < vec_end; k += 32) { __m128i raw0 = _mm_loadu_si128(reinterpret_cast(b_row + k)); __m512i i32_0 = _mm512_cvtepi8_epi32(raw0); __m512 bf0 = _mm512_cvtepi32_ps(i32_0); - bf0 = _mm512_mul_ps(bf0, scale_vec); __m512 a0 = _mm512_loadu_ps(a_row + k); acc0 = _mm512_fmadd_ps(a0, bf0, acc0); __m128i raw1 = _mm_loadu_si128(reinterpret_cast(b_row + k + 16)); __m512i i32_1 = _mm512_cvtepi8_epi32(raw1); __m512 bf1 = _mm512_cvtepi32_ps(i32_1); - bf1 = _mm512_mul_ps(bf1, scale_vec); __m512 a1 = _mm512_loadu_ps(a_row + k + 16); acc1 = _mm512_fmadd_ps(a1, bf1, acc1); } @@ -244,7 +244,6 @@ FusedDotInt8_Avx512( __m128i raw0 = _mm_loadu_si128(reinterpret_cast(b_row + k)); __m512i i32_0 = _mm512_cvtepi8_epi32(raw0); __m512 bf0 = _mm512_cvtepi32_ps(i32_0); - bf0 = _mm512_mul_ps(bf0, scale_vec); __m512 a0 = _mm512_loadu_ps(a_row + k); acc0 = _mm512_fmadd_ps(a0, bf0, acc0); } @@ -254,9 +253,15 @@ FusedDotInt8_Avx512( float dot = _mm512_reduce_add_ps(acc0); // Scalar tail - for (; k < K; ++k) { - float sc = per_channel ? 
scales[k] : scales[0]; - dot += a_row[k] * static_cast(b_row[k]) * sc; + if (per_channel) { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]) * scales[k]; + } + } else { + for (; k < K; ++k) { + dot += a_row[k] * static_cast(b_row[k]); + } + dot *= scales[0]; } return dot; } From b2bd847036006d8d97b29c4fd8dd03587e0d4576 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 21 May 2026 01:56:12 +0000 Subject: [PATCH 4/5] perf: defer per-tensor scale in KV quant kernels Factor single-scale per-tensor multiplication out of hot loops in the quantized KV-cache GEMM kernels where possible: - QK INT8 fused dot paths defer scale until after accumulation. - SV INT8 per-tensor paths scale the output row once after accumulation. - NEON SV per-tensor dequantization can leave rows unscaled and scale C once. Also clarify AVX2 INT4 nibble extraction and use uint32_t for the raw packed load. --- .../core/mlas/lib/qkv_quant_kernel_avx2.cpp | 24 +++++--- .../mlas/lib/qkv_quant_kernel_avx512vnni.cpp | 16 +++++- .../core/mlas/lib/qkv_quant_kernel_neon.cpp | 57 ++++++++++++++----- 3 files changed, 73 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp index 95985ac37f371..8bec2d350afa5 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp @@ -42,16 +42,16 @@ DequantInt4x8(const uint8_t* src, size_t col, bool per_channel, const float* sca // Load 4 packed bytes safely without strict-aliasing / alignment UB. // Compilers optimize memcpy of 4 bytes to a single mov instruction. 
- int raw_bytes; + uint32_t raw_bytes; std::memcpy(&raw_bytes, base, sizeof(raw_bytes)); - __m128i packed = _mm_cvtsi32_si128(raw_bytes); + __m128i packed = _mm_cvtsi32_si128(static_cast(raw_bytes)); // Low nibbles (even columns): AND with 0x0F __m128i lo_mask = _mm_set1_epi8(0x0F); __m128i lo = _mm_and_si128(packed, lo_mask); - // High nibbles (odd columns): shift right 4 using 32-bit granularity - // to prevent bit bleeding across 16-bit lane boundaries, then mask. + // High nibbles (odd columns): shift right by 4 within 32-bit lanes, then mask. + // Any cross-byte bits from the shift land in the upper nibble and are discarded by the mask. __m128i hi = _mm_and_si128(_mm_srli_epi32(packed, 4), lo_mask); // Interleave low and high nibbles: [lo0,hi0, lo1,hi1, lo2,hi2, lo3,hi3] @@ -331,7 +331,7 @@ SVGemm_Avx2( } } } else { - __m256 scale_vec = _mm256_broadcast_ss(Scales); + // Per-tensor: accumulate unscaled dot products, then scale the output row once. for (size_t k = 0; k < K; ++k) { const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); const float a_val = a_row[k]; @@ -342,15 +342,25 @@ SVGemm_Avx2( __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); __m256i i32 = _mm256_cvtepi8_epi32(raw); __m256 bf = _mm256_cvtepi32_ps(i32); - bf = _mm256_mul_ps(bf, scale_vec); __m256 c_vec = _mm256_loadu_ps(c_row + n); c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); _mm256_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]) * Scales[0]; + c_row[n] += a_val * static_cast(b_row[n]); } } + + __m256 scale_vec = _mm256_broadcast_ss(Scales); + n = 0; + for (; n < vec_end_n; n += 8) { + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_mul_ps(c_vec, scale_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } } } else { // INT4 fused path diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp 
index 5f614bd22e8f4..8d3636d461e19 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -577,7 +577,7 @@ SVGemm_Avx512Vnni( } } } else { - __m512 scale_vec = _mm512_set1_ps(Scales[0]); + // Per-tensor: accumulate unscaled dot products, then scale the output row once. for (size_t k = 0; k < K; ++k) { const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); const float a_val = a_row[k]; @@ -588,15 +588,25 @@ SVGemm_Avx512Vnni( __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); __m512i i32 = _mm512_cvtepi8_epi32(raw); __m512 bf = _mm512_cvtepi32_ps(i32); - bf = _mm512_mul_ps(bf, scale_vec); __m512 c_vec = _mm512_loadu_ps(c_row + n); c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); _mm512_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]) * Scales[0]; + c_row[n] += a_val * static_cast(b_row[n]); } } + + __m512 scale_vec = _mm512_set1_ps(Scales[0]); + n = 0; + for (; n < vec_end_n; n += 16) { + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_mul_ps(c_vec, scale_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } } } else { // INT4 path: 512-bit wide diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp index ae5a56028bbf9..1aabbd8ca39cb 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp @@ -29,12 +29,13 @@ using namespace MlasKVQuantInternal; namespace { // -// Dequantize 8 INT8 values starting at `col` and scale them. +// Dequantize 8 INT8 values starting at `col`. +// Per-channel rows are always scaled. Per-tensor rows may defer scaling. // Produces two float32x4_t (8 floats total) stored into dst. 
// inline void DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, - const float* scales, float* dst) + const float* scales, bool apply_per_tensor_scale, float* dst) { // Load 8 int8 values int8x8_t raw = vld1_s8(src + col); @@ -52,7 +53,7 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, float32x4_t sc_hi = vld1q_f32(scales + col + 4); f_lo = vmulq_f32(f_lo, sc_lo); f_hi = vmulq_f32(f_hi, sc_hi); - } else { + } else if (apply_per_tensor_scale) { float32x4_t sc = vdupq_n_f32(scales[0]); f_lo = vmulq_f32(f_lo, sc); f_hi = vmulq_f32(f_hi, sc); @@ -64,10 +65,11 @@ DequantInt8x8_Neon(const int8_t* src, size_t col, bool per_channel, // // Dequantize 8 INT4 values (4 packed bytes) starting at even column `col`. +// Per-channel rows are always scaled. Per-tensor rows may defer scaling. // inline void DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, - const float* scales, float* dst) + const float* scales, bool apply_per_tensor_scale, float* dst) { const uint8_t* base = src + col / 2; @@ -94,7 +96,7 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, float32x4_t sc_hi = vld1q_f32(scales + col + 4); f_lo = vmulq_f32(f_lo, sc_lo); f_hi = vmulq_f32(f_hi, sc_hi); - } else { + } else if (apply_per_tensor_scale) { float32x4_t sc = vdupq_n_f32(scales[0]); f_lo = vmulq_f32(f_lo, sc); f_hi = vmulq_f32(f_hi, sc); @@ -106,6 +108,8 @@ DequantInt4x8_Neon(const uint8_t* src, size_t col, bool per_channel, // // Dequantize one row of length `cols` from packed quantized buffer into `dst`. +// `apply_per_tensor_scale=false` leaves per-tensor rows unscaled so callers can +// factor the single scale out of an outer accumulation loop. 
// void DequantRow_Neon( @@ -113,7 +117,8 @@ DequantRow_Neon( float* dst, size_t cols, MLAS_KV_QUANT_TYPE qt, - const float* scales) + const float* scales, + bool apply_per_tensor_scale) { const bool int4 = IsInt4Mode(qt); const bool per_channel = IsPerChannelMode(qt); @@ -124,22 +129,32 @@ DequantRow_Neon( if (!int4) { const auto* src = static_cast(src_raw); for (; c < vec_end; c += 8) { - DequantInt8x8_Neon(src, c, per_channel, scales, dst + c); + DequantInt8x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c); } for (; c < cols; ++c) { - float sc = per_channel ? scales[c] : scales[0]; - dst[c] = static_cast(src[c]) * sc; + if (per_channel) { + dst[c] = static_cast(src[c]) * scales[c]; + } else if (apply_per_tensor_scale) { + dst[c] = static_cast(src[c]) * scales[0]; + } else { + dst[c] = static_cast(src[c]); + } } } else { const auto* src = static_cast(src_raw); for (; c < vec_end; c += 8) { - DequantInt4x8_Neon(src, c, per_channel, scales, dst + c); + DequantInt4x8_Neon(src, c, per_channel, scales, apply_per_tensor_scale, dst + c); } for (; c < cols; ++c) { uint8_t packed = src[c / 2]; int nibble = (c & 1) == 0 ? (packed & 0x0F) : ((packed >> 4) & 0x0F); - float sc = per_channel ? 
scales[c] : scales[0]; - dst[c] = static_cast(nibble - kInt4Bias) * sc; + if (per_channel) { + dst[c] = static_cast(nibble - kInt4Bias) * scales[c]; + } else if (apply_per_tensor_scale) { + dst[c] = static_cast(nibble - kInt4Bias) * scales[0]; + } else { + dst[c] = static_cast(nibble - kInt4Bias); + } } } } @@ -174,7 +189,7 @@ QKGemm_Neon( for (size_t n = 0; n < N; ++n) { const uint8_t* b_row = B_bytes + n * row_bytes; - DequantRow_Neon(b_row, b_buf, K, QuantType, Scales); + DequantRow_Neon(b_row, b_buf, K, QuantType, Scales, true); for (size_t m = 0; m < M; ++m) { const float* a_row = A + m * lda; @@ -246,6 +261,7 @@ SVGemm_Neon( { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); + const bool per_channel = IsPerChannelMode(QuantType); float b_stack[256]; float* b_buf = b_stack; @@ -272,7 +288,7 @@ SVGemm_Neon( for (size_t k = 0; k < K; ++k) { const uint8_t* b_row_packed = B_bytes + k * row_bytes; - DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales); + DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, per_channel); const float a_val = a_row[k]; float32x4_t a_broadcast = vdupq_n_f32(a_val); @@ -288,6 +304,19 @@ SVGemm_Neon( c_row[n] += a_val * b_buf[n]; } } + + if (!per_channel) { + const float32x4_t scale_vec = vdupq_n_f32(Scales[0]); + n = 0; + for (; n < vec_end_n; n += 4) { + float32x4_t c_vec = vld1q_f32(c_row + n); + c_vec = vmulq_f32(c_vec, scale_vec); + vst1q_f32(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Scales[0]; + } + } } } From 27782038e4ea4b8103facadc8919ab1efffee9e3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 21 May 2026 14:59:25 -0700 Subject: [PATCH 5/5] fix: use nearbyintf in scalar tail to match AVX-512 round-to-nearest-even The scalar tail in QuantizeRowToU8 used std::round (ties away from zero) while the vectorized path uses _mm512_cvt_roundps_epi32 with round-to- nearest-even semantics. 
Switch to std::nearbyintf, which rounds according to the current floating-point rounding mode; that mode is round-to-nearest-even by default, so the tail matches the vector path unless the rounding mode has been changed (e.g. via fesetround). --- onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index 8d3636d461e19..fa5aff0165897 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -94,9 +94,10 @@ QuantizeRowToU8(const float* src, uint8_t* dst, size_t len) __m128i packed = _mm512_cvtepi32_epi8(qi); _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + i), packed); } - // Scalar tail + // Scalar tail (use nearbyintf for round-to-nearest-even, matching the + // AVX-512 embedded rounding semantics above). for (; i < len; ++i) { - float q = std::round(src[i] * inv_scale) + 128.0f; + float q = std::nearbyintf(src[i] * inv_scale) + 128.0f; q = std::max(0.0f, std::min(255.0f, q)); dst[i] = static_cast(q); }