Skip to content

Optimize MatMulNBits 2-bit + float zero_point CPU dequantization with multi-threaded kernel#28589

Open
Copilot wants to merge 4 commits into
mainfrom
copilot/optimize-matmulnbits-performance
Open

Optimize MatMulNBits 2-bit + float zero_point CPU dequantization with multi-threaded kernel#28589
Copilot wants to merge 4 commits into
mainfrom
copilot/optimize-matmulnbits-performance

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented May 20, 2026

Description

Replace the naive single-threaded scalar loop for 2-bit dequantization with float/MLFloat16 zero points with a multi-threaded kernel (DequantizeBlockwise2Bits) that:

  • Parallelizes via TrySimpleParallelFor — distributes work across all intra-op threads (previously single-threaded)
  • Processes 16 elements per iteration — one uint32_t = 16 packed 2-bit values, reducing per-element overhead
  • Hoists scale/zp lookups — all 16 elements share a block, so scale and zero_point are loaded once per batch

Follows the same threading pattern as the existing 4-bit DequantizeBlockwise path for consistency.

Files changed:

  • matmul_nbits_impl.h — declare DequantizeBlockwise2Bits
  • matmul_nbits_impl.cc — implement Dequantize2BitsKernel + DequantizeBlockwise2Bits with instantiations for <float,float> and <float,MLFloat16>
  • matmul_nbits.cc — replace naive loops in both MatMulNBits<float> and MatMulNBits<MLFloat16> ComputeBUnpacked

Motivation and Context

The bits=2 + float zero_point path (added in #28354) was flagged with // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!. It ran ~20× slower than the bits=4 MLAS path because it was a tight scalar for n × for k loop with no threading — the entire N×K dequantization ran on a single core before calling MlasGemmBatch. With 8 intra-op threads this should recover most of that gap.

Benchmark Results

Tested on a 96-core x86_64 Linux machine, ORT 1.27.0 CPU Release build, using typical LLM matrix shapes with block_size=128 and float zero points.

Multi-thread speedup (2-bit dequantization, 1 thread → 8 threads)

Shape (M×K×N) 1-thread (ms) 8-thread (ms) Speedup
1×4096×4096 41.0 8.5 4.84×
32×4096×4096 47.9 8.8 5.46×
1×4096×11008 120.7 24.2 4.99×
32×4096×11008 146.8 28.2 5.21×
1×11008×4096 119.2 24.5 4.87×
32×11008×4096 154.4 28.2 5.47×
1×1024×1024 1.18 0.16 7.61×

2-bit vs 4-bit comparison (ratio = 2-bit / 4-bit; <1.0 means 2-bit is faster)

Shape (M×K×N) Threads 4-bit (ms) 2-bit (ms) Ratio
1×4096×4096 1 52.0 41.0 0.79×
1×4096×4096 8 9.4 8.5 0.90×
1×4096×11008 1 141.6 120.7 0.85×
1×4096×11008 8 26.8 24.2 0.90×
1×11008×4096 1 141.2 119.2 0.84×
1×11008×4096 8 26.6 24.5 0.92×
32×4096×4096 1 56.1 47.9 0.85×
32×4096×4096 8 9.6 8.8 0.92×
1×1024×1024 1 1.66 1.18 0.71×

Key findings:

…ti-threaded kernel

Replace the naive single-threaded scalar loop for 2-bit quantization with
float/MLFloat16 zero points with a multi-threaded implementation using
TrySimpleParallelFor. The new DequantizeBlockwise2Bits function processes
16 elements (one uint32 of packed 2-bit values) per iteration and
distributes work across available threads, matching the parallelism
pattern used by the existing 4-bit DequantizeBlockwise path.

Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/76231b1d-cdea-427a-8824-29293b1d02eb

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h Outdated
Copilot AI changed the title [WIP] Optimize MatMulNBits performance for bits=2 and float zero_point Optimize MatMulNBits 2-bit + float zero_point CPU dequantization with multi-threaded kernel May 20, 2026
Copilot AI requested a review from tianleiwu May 20, 2026 16:54
@tianleiwu tianleiwu marked this pull request as ready for review May 20, 2026 18:17
@tianleiwu tianleiwu requested a review from Copilot May 20, 2026 18:24
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Optimizes the CPU fallback path for MatMulNBits when bits=2 and zero_points are float/MLFloat16 by replacing a scalar single-thread dequantization loop with a threaded, blockwise dequantization kernel, aiming to remove the large performance regression reported for this configuration.

Changes:

  • Adds a multi-threaded DequantizeBlockwise2Bits kernel that processes 16 values per iteration and parallelizes via TrySimpleParallelFor.
  • Switches the MatMulNBits<float> and MatMulNBits<MLFloat16> unpacked compute paths to use the new 2-bit dequant kernel.
  • Adds a Python benchmark script for measuring 2-bit vs 4-bit performance and thread scaling.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
onnxruntime/test/python/quantization/bench_matmul_2bits.py Adds a standalone benchmark script for MatMulNBits 2-bit float-ZP performance on CPU.
onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc Replaces naive 2-bit float/fp16-ZP dequant loops with DequantizeBlockwise2Bits calls.
onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h Declares DequantizeBlockwise2Bits template API.
onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc Implements the threaded 2-bit dequantization kernel and adds explicit instantiations.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc Outdated
Comment thread onnxruntime/test/python/quantization/bench_matmul_2bits.py Outdated
@tianleiwu tianleiwu requested a review from justinchuby May 20, 2026 18:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MatMulNBits bits=2 + float zero_point CPU path is ~20× slower than bits=4 (naive scalar fallback)

3 participants