Optimize MLAS quantized KV-cache GEMM kernels (follow-up to #28578) #28606
Open
tianleiwu wants to merge 4 commits into
Replace _mm512_roundscale_ps + _mm512_cvtps_epi32 with a single _mm512_cvt_roundps_epi32 that combines round-to-nearest-even and float-to-int32 conversion in one instruction, saving a vrndscaleps per loop iteration. Clamp moved before convert (same results since boundary values 0.0/255.0 are already integers).
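The claim that clamping before the convert gives the same results can be checked with a scalar sketch. The two helpers below are hypothetical stand-ins, not the MLAS code: they assume the default round-to-nearest-even floating-point mode and inputs small enough that the float-to-int conversion stays in range.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical scalar stand-in for the two-step sequence:
// round to an integral float, convert, then clamp in the integer domain.
inline std::uint8_t QuantizeRoundThenClamp(float x) {
    assert(std::fabs(x) < 1.0e9f);  // keep the float-to-int conversion in range
    const float rounded = std::nearbyint(x);  // ties-to-even in the default rounding mode
    const int q = static_cast<int>(rounded);
    return static_cast<std::uint8_t>(std::min(std::max(q, 0), 255));
}

// Hypothetical scalar stand-in for the fused form:
// clamp in the float domain, then round and convert in a single step.
inline std::uint8_t QuantizeClampThenRound(float x) {
    const float clamped = std::min(std::max(x, 0.0f), 255.0f);
    // lrint rounds according to the current mode (nearest-even by default).
    return static_cast<std::uint8_t>(std::lrint(clamped));
}
```

Because the clamp bounds 0.0 and 255.0 are integers and rounding is monotonic, both helpers agree on every in-range input, including ties such as 254.5 and 255.5.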
Perform the 128*sum(b) zero-point correction in int32 before converting to float. This avoids precision loss when |dot_i32| or |128*b_sum_i32| exceeds 2^24 (where float32 loses integer precision), preventing potential catastrophic cancellation in the float subtraction. Applied to both VnniDotInt8PerTensor and VnniMultiDot4Int8PerTensor. Overflow is not a concern: for typical K<=16384, the max value is 128 * K * 127 ≈ 266M << INT32_MAX.
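A scalar sketch of why the order matters, using hypothetical helper names rather than the actual VNNI kernels. The inputs are chosen so that the exact corrected value is 1 while both operands sit at or just above 2^24.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical: convert both terms to float first, then subtract.
// Each conversion may round once the magnitude exceeds 2^24.
inline float CorrectInFloat(std::int32_t dot_u8s8, std::int32_t b_sum) {
    return static_cast<float>(dot_u8s8) - 128.0f * static_cast<float>(b_sum);
}

// Hypothetical: subtract exactly in int32, then convert the small result once.
inline float CorrectInInt32(std::int32_t dot_u8s8, std::int32_t b_sum) {
    // Assumption: 128 * b_sum and the difference both fit in int32.
    assert(b_sum > -(1 << 23) && b_sum < (1 << 23));
    const std::int32_t corrected = dot_u8s8 - 128 * b_sum;
    return static_cast<float>(corrected);
}
```

With dot_u8s8 = 16777217 (2^24 + 1) and b_sum = 131072 (so 128 * b_sum = 2^24), CorrectInInt32 returns exactly 1.0f. CorrectInFloat cannot, because 2^24 + 1 has no float32 representation and is rounded before the subtraction.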
For the per-tensor (single-scale) INT8 path, factor out the constant scale multiplication from the inner loop using the distributive property: sum(a[k] * b[k] * scale) = scale * sum(a[k] * b[k]). This saves one vmulps per 8 elements (AVX2) or 16 elements (AVX-512) in the hot loop. The final result is multiplied by scale once after accumulation. Numerically equivalent within FP rounding.
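The rewrite can be illustrated with a scalar sketch. The function names and the float-by-int8 operand types are assumptions made for illustration; the real kernels are vectorized.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical baseline: the per-tensor scale is applied to every element.
inline float DotScalePerElement(const float* a, const std::int8_t* b,
                                std::size_t k, float scale) {
    assert(k == 0 || (a != nullptr && b != nullptr));
    float acc = 0.0f;
    for (std::size_t i = 0; i < k; ++i) {
        acc += a[i] * (static_cast<float>(b[i]) * scale);  // one extra multiply per element
    }
    return acc;
}

// Hypothetical deferred form: accumulate unscaled, multiply by scale once.
inline float DotScaleDeferred(const float* a, const std::int8_t* b,
                              std::size_t k, float scale) {
    assert(k == 0 || (a != nullptr && b != nullptr));
    float acc = 0.0f;
    for (std::size_t i = 0; i < k; ++i) {
        acc += a[i] * static_cast<float>(b[i]);
    }
    return acc * scale;  // single multiply after the loop
}
```

For a = {1, 2, 3, 4}, b = {5, 6, 7, 8} and scale = 0.5, both forms return 35. With scales that are not powers of two the two results can differ in the last bits, which is the "within FP rounding" caveat above.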
Factor single-scale per-tensor multiplication out of hot loops in the quantized KV-cache GEMM kernels where possible:
- QK INT8 fused dot paths defer scale until after accumulation.
- SV INT8 per-tensor paths scale the output row once after accumulation.
- NEON SV per-tensor dequantization can leave rows unscaled and scale C once.
Also clarify AVX2 INT4 nibble extraction and use uint32_t for the raw packed load.
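As a rough illustration of the INT4 point, the sketch below unpacks eight 4-bit values from one 32-bit word. The helper names and the layout (low nibble first, little-endian byte order) are assumptions. The actual packing and any signed offset used by the kernels are not stated here.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical: assemble four packed bytes into an unsigned 32-bit word.
// On little-endian x86 this corresponds to a single 32-bit load.
inline std::uint32_t LoadPacked32(const std::uint8_t* p) {
    assert(p != nullptr);
    return static_cast<std::uint32_t>(p[0]) |
           (static_cast<std::uint32_t>(p[1]) << 8) |
           (static_cast<std::uint32_t>(p[2]) << 16) |
           (static_cast<std::uint32_t>(p[3]) << 24);
}

// Hypothetical: extract eight unsigned nibbles, lowest nibble first.
inline std::array<std::uint8_t, 8> UnpackNibbles(std::uint32_t packed) {
    std::array<std::uint8_t, 8> out{};
    for (int i = 0; i < 8; ++i) {
        // Unsigned shifts are logical, so no sign bits leak into the result.
        out[i] = static_cast<std::uint8_t>((packed >> (4 * i)) & 0xFu);
    }
    return out;
}
```

An unsigned type keeps the right shifts logical. That is one plausible reason for preferring uint32_t for the raw word, though the commit message does not spell out the motivation.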
Description
Follow-up performance and correctness improvements to the MLAS quantized KV-cache GEMM kernels introduced in #28578. These changes target the AVX2, AVX512-VNNI, and NEON kernel files only.
Changes
1. Use embedded rounding in QuantizeRowToU8 (AVX-512)
Replace _mm512_roundscale_ps + _mm512_cvtps_epi32 with a single _mm512_cvt_roundps_epi32 that combines round-to-nearest-even and float-to-int32 conversion in one instruction, saving a vrndscaleps per loop iteration.
2. Use int32 zero-point correction in VNNI dot products
Perform the dot - 128*sum(b) zero-point correction in int32 before converting to float. This avoids precision loss when operands exceed 2^24 (where float32 loses integer precision), preventing potential catastrophic cancellation.
3. Defer per-tensor scale in FusedDotInt8 (AVX2 + AVX-512)
Factor the constant per-tensor scale out of the inner loop: sum(a*b*s) = s * sum(a*b). This saves one vmulps per 8 elements (AVX2) or 16 elements (AVX-512) in the hot path.
4. Defer per-tensor scale in SVGemm and NEON dequantization
- SVGemm: accumulate unscaled dot products, then multiply the output row by the per-tensor scale once after the K loop.
- NEON: DequantRow_Neon uses apply_per_tensor_scale to skip per-element scaling during dequantization in per-tensor mode; the output row is scaled once after accumulation.
- AVX2 INT4: clarify nibble extraction and use uint32_t for the raw packed load.
Motivation
The per-tensor quantization paths were previously applying a constant scale factor on every element inside hot loops. By deferring the scalar multiplication to after accumulation (using the distributive property), we reduce instruction count in the inner loops without changing numerical results (within normal FP reordering tolerance).
The int32 zero-point correction fix addresses a latent precision issue in AVX512-VNNI paths that could manifest at large K dimensions (K > ~512).
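A back-of-the-envelope bound shows where the K threshold comes from. It assumes, as is usual for VNNI u8×s8 dot products but not stated explicitly above, that the signed operand a is shifted by +128 into the unsigned range and that b is int8; the symbols a_k and b_k are introduced here for illustration.

```latex
\begin{aligned}
\sum_{k=1}^{K} (a_k + 128)\, b_k &= \sum_{k=1}^{K} a_k b_k + 128 \sum_{k=1}^{K} b_k \\
\Bigl|\sum_{k=1}^{K} (a_k + 128)\, b_k\Bigr| &\le 255 \cdot 128 \cdot K = 32640\,K \\
32640\,K > 2^{24} = 16\,777\,216 &\iff K > 514
\end{aligned}
```

Under these assumptions the raw accumulator can first exceed the float32 exact-integer range at K = 515 in the worst case, which is consistent with the K > ~512 figure.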
Testing
- onnxruntime_mlas_test --gtest_filter=KVQuant.* passes (Debug build, x86-64).
- KVQuant.ShortExecute exercises all modified code paths across INT8/INT4 per-tensor/per-channel modes.
Benchmark Results
Measured on Intel Xeon Platinum 8370C (8 cores, 16 threads, AVX-512 + VNNI), Release build. Each benchmark uses --benchmark_min_time=0.3s --benchmark_repetitions=5.
QKGemm (query × K_cache^T): INT8 per-tensor (S8_PerTensor, QuantType:0)
This is the path most improved by the deferred-scale optimization (changes 3 and 4).
SVGemm (attn_probs × V_cache): INT8 per-tensor (S8_PerTensor, QuantType:0)
Other quant types (S8_PerChannel, S4_PerTensor, S4_PerChannel): neutral
Per-channel and INT4 paths are not affected by the deferred-scale optimization. Representative M=128 results:
Summary: The INT8 per-tensor paths (the most common decode configuration) see 12–36% QKGemm speedup and 4–15% SVGemm speedup at representative shapes. Other quantization modes are neutral within noise (±1–3%).