Optimize MLAS quantized KV-cache GEMM kernels (follow-up to #28578) #28606

Open
tianleiwu wants to merge 4 commits into microsoft:main from tianleiwu:tlwu/20260520/gqa_cpu_quantized_kv_mlas_opt

tianleiwu (Contributor) commented May 21, 2026

Description

Follow-up performance and correctness improvements to the MLAS quantized KV-cache GEMM kernels introduced in #28578. These changes target the AVX2, AVX512-VNNI, and NEON kernel files only.

Changes

  1. Use embedded rounding in QuantizeRowToU8 (AVX-512)
    Replace _mm512_roundscale_ps + _mm512_cvtps_epi32 with a single _mm512_cvt_roundps_epi32 that combines round-to-nearest-even and float-to-int32 in one instruction, saving a vrndscaleps per loop iteration.

  2. Use int32 zero-point correction in VNNI dot products
    Perform the dot - 128*sum(b) zero-point correction in int32 before converting to float. This avoids precision loss when operands exceed 2^24 (where float32 loses integer precision), preventing potential catastrophic cancellation.

  3. Defer per-tensor scale in FusedDotInt8 (AVX2 + AVX-512)
    Factor the constant per-tensor scale out of the inner loop: sum(a*b*s) = s * sum(a*b). Saves one vmulps per 8/16 elements in the hot path.

  4. Defer per-tensor scale in SVGemm and NEON dequantization

    • AVX2/AVX-512 SVGemm: accumulate unscaled dot products, multiply the output row by the per-tensor scale once after the K loop.
    • NEON: parameterize DequantRow_Neon with apply_per_tensor_scale to skip per-element scaling during dequantization when using per-tensor mode; scale the output row once after accumulation.
    • Also: clarify AVX2 INT4 nibble extraction comment and use uint32_t for the raw packed load.
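The rounding change in item 1 relies on clamping before the conversion giving the same result as rounding first. The scalar sketch below illustrates that equivalence for finite inputs. It is not the MLAS kernel: the function names are hypothetical, and `std::nearbyint` (ties-to-even under the default rounding mode) stands in for the vector conversion.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical scalar reference: round to nearest even, then clamp to [0, 255].
inline uint8_t RoundThenClampToU8(float v) {
    // std::nearbyint uses the current rounding mode (ties-to-even by default).
    const float rounded = std::nearbyint(v);
    return static_cast<uint8_t>(std::min(std::max(rounded, 0.0f), 255.0f));
}

// Hypothetical scalar reference: clamp first, then round to nearest even.
// The bounds 0.0 and 255.0 are integers, so rounding cannot move a clamped
// value outside [0, 255].
inline uint8_t ClampThenRoundToU8(float v) {
    const float clamped = std::min(std::max(v, 0.0f), 255.0f);
    return static_cast<uint8_t>(std::nearbyint(clamped));
}
```

Both orderings agree on ties such as 2.5 (rounds to 2) and 3.5 (rounds to 4), and on out-of-range inputs, which saturate to 0 or 255.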

Motivation

The per-tensor quantization paths were previously applying a constant scale factor on every element inside hot loops. By deferring the scalar multiplication to after accumulation (using the distributive property), we reduce instruction count in the inner loops without changing numerical results (within normal FP reordering tolerance).

The int32 zero-point correction fix addresses a latent precision issue in AVX512-VNNI paths that could manifest at large K dimensions (K > ~512).
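A small scalar sketch of why the order of operations matters. The function names and the input values are illustrative assumptions, not taken from the kernels. Above 2^24, float32 can no longer represent every integer, so converting the raw dot product before subtracting can lose the low bit.

```cpp
#include <cstdint>

// Correction applied after converting each operand to float.
// An int32 value above 2^24 may be rounded during the conversion.
inline float CorrectInFloat(int32_t dot, int32_t b_sum) {
    return static_cast<float>(dot) - 128.0f * static_cast<float>(b_sum);
}

// Correction applied in int32; only the small, exact difference is converted.
inline float CorrectInInt32(int32_t dot, int32_t b_sum) {
    return static_cast<float>(dot - 128 * b_sum);
}
```

With the illustrative inputs `dot = 20000001` and `b_sum = 156249`, the exact answer is 129. The int32 path returns 129.0f, while the float path returns 128.0f because 20000001 rounds to 20000000 when converted to float32.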

Testing

  • onnxruntime_mlas_test --gtest_filter=KVQuant.* passes (Debug build, x86-64).
  • No new tests needed — existing KVQuant.ShortExecute exercises all modified code paths across INT8/INT4 per-tensor/per-channel modes.

Benchmark Results

Measured on Intel Xeon Platinum 8370C (8 cores, 16 threads, AVX-512 + VNNI), Release build. Each benchmark uses --benchmark_min_time=0.3s --benchmark_repetitions=5.

QKGemm (query × K_cache^T) — INT8 per-tensor (S8_PerTensor, QuantType:0)

This is the path most improved by the deferred-scale optimization (changes 3 and 4).

| Shape | Before (ns) | After (ns) | Speedup |
|---|---|---|---|
| M=1, N=512, K=64 | 2,926 | 2,803 | 1.04x |
| M=1, N=512, K=128 | 5,914 | 5,074 | 1.17x |
| M=1, N=2048, K=128 | 22,401 | 19,937 | 1.12x |
| M=128, N=512, K=64 | 412,505 | 304,230 | 1.36x |
| M=128, N=512, K=128 | 911,508 | 788,198 | 1.16x |
| M=128, N=2048, K=64 | 1,662,547 | 1,242,441 | 1.34x |
| M=128, N=2048, K=128 | 3,660,599 | 3,176,911 | 1.15x |

SVGemm (attn_probs × V_cache) — INT8 per-tensor (S8_PerTensor, QuantType:0)

| Shape | Before (ns) | After (ns) | Speedup |
|---|---|---|---|
| M=1, N=64, K=512 | 4,707 | 4,122 | 1.14x |
| M=1, N=64, K=2048 | 18,516 | 16,533 | 1.12x |
| M=128, N=64, K=512 | 399,703 | 358,821 | 1.11x |
| M=128, N=64, K=2048 | 1,633,807 | 1,423,984 | 1.15x |
| M=128, N=128, K=512 | 775,205 | 761,527 | 1.02x |
| M=128, N=128, K=2048 | 3,086,642 | 2,979,566 | 1.04x |

Other quant types (S8_PerChannel, S4_PerTensor, S4_PerChannel) — neutral

Per-channel and INT4 paths are not affected by the deferred-scale optimization. Representative M=128 results:

| Benchmark | QuantType | Before (ns) | After (ns) | Ratio |
|---|---|---|---|---|
| QKGemm M=128, N=2048, K=128 | S8_PerChannel | 4,555,381 | 4,684,954 | 0.97x |
| QKGemm M=128, N=2048, K=128 | S4_PerTensor | 3,841,759 | 3,819,387 | 1.01x |
| QKGemm M=128, N=2048, K=128 | S4_PerChannel | 4,043,262 | 4,056,033 | 1.00x |
| SVGemm M=128, N=128, K=2048 | S8_PerChannel | 4,449,839 | 4,290,344 | 1.04x |
| SVGemm M=128, N=128, K=2048 | S4_PerTensor | 2,989,684 | 2,998,154 | 1.00x |
| SVGemm M=128, N=128, K=2048 | S4_PerChannel | 3,403,497 | 3,390,452 | 1.00x |

Summary: The INT8 per-tensor paths (the most common decode configuration) see a 4–36% QKGemm speedup and a 2–15% SVGemm speedup across the measured shapes. Other quantization modes are neutral within noise (ratios between 0.97x and 1.04x).

tianleiwu added 4 commits May 21, 2026 04:59

  1. Replace _mm512_roundscale_ps + _mm512_cvtps_epi32 with a single
     _mm512_cvt_roundps_epi32 that combines round-to-nearest-even and
     float-to-int32 conversion in one instruction, saving a vrndscaleps
     per loop iteration. Clamp moved before convert (same results since
     boundary values 0.0/255.0 are already integers).

  2. Perform the 128*sum(b) zero-point correction in int32 before converting
     to float. This avoids precision loss when |dot_i32| or |128*b_sum_i32|
     exceed 2^24 (where float32 loses integer precision), preventing
     potential catastrophic cancellation in the float subtraction.

     Applied to both VnniDotInt8PerTensor and VnniMultiDot4Int8PerTensor.
     Overflow is not a concern: for typical K<=16384, the max value is
     128 * K * 127 ≈ 266M << INT32_MAX.

  3. For the per-tensor (single-scale) INT8 path, factor out the constant
     scale multiplication from the inner loop using the distributive property:
       sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k])

     This saves one vmulps per 8 elements (AVX2) or 16 elements (AVX-512)
     in the hot loop. The final result is multiplied by scale once after
     accumulation. Numerically equivalent within FP rounding.

  4. Factor single-scale per-tensor multiplication out of hot loops in the
     quantized KV-cache GEMM kernels where possible:
     - QK INT8 fused dot paths defer scale until after accumulation.
     - SV INT8 per-tensor paths scale the output row once after accumulation.
     - NEON SV per-tensor dequantization can leave rows unscaled and scale C once.

     Also clarify AVX2 INT4 nibble extraction and use uint32_t for the raw packed load.
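The overflow bound quoted in the zero-point commit message can be checked at compile time. The helper below is a hypothetical illustration, not code from the PR, and it covers only the 128*sum(b) correction term with every b[k] at magnitude 127.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper: worst-case magnitude of 128 * sum(b) over K int8
// values, assuming |b[k]| <= 127. Computed in int64 so the check itself
// cannot overflow.
constexpr int64_t MaxZeroPointCorrection(int64_t k) {
    return 128 * k * 127;
}

// For K = 16384 the bound is 266,338,304 (about 266M), well below INT32_MAX.
static_assert(MaxZeroPointCorrection(16384) == 266338304, "bound for K = 16384");
static_assert(MaxZeroPointCorrection(16384) < std::numeric_limits<int32_t>::max(),
              "correction term fits in int32");
```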