MatMulNBits bits=2 + float zero_point CPU path is ~20× slower than bits=4 (naive scalar fallback) #28552

@justinchuby

Background

microsoft/onnxruntime#28354 (ORT 1.27) added a CPU fallback for MatMulNBits with bits=2 and float-valued zero_points, enabling AMD QAD / Tencent SEQ-style 2-bit codebooks whose offset is non-integer (e.g. 1.5). The implementation is a scalar nested loop (for n, then for k) that dequantizes all of B before calling MlasGemmBatch, and the source itself flags it as preliminary:

// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!

(contrib_ops/cpu/quantization/matmul_nbits.cc, inside ComputeBUnpacked for nbits_ == 2 with zero_points->IsDataType<float>()).
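For readers unfamiliar with the operator, the following is a minimal pure-Python sketch of what an element-at-a-time dequant of this kind computes. It is not the ORT source. The helper names, the low-bits-first packing order, and the `[n][block]` layout of scales and zero points are assumptions made for illustration; the only thing taken from the usual MatMulNBits convention is the formula `w = (q - zero_point) * scale`.

```python
def unpack_2bit(packed, count):
    """Unpack `count` 2-bit codes from bytes, lowest bits first (assumed layout)."""
    return [(packed[i // 4] >> (2 * (i % 4))) & 0b11 for i in range(count)]


def dequantize_naive(packed_rows, scales, zero_points, K, block_size):
    """Dequantize an N x K weight matrix one element at a time.

    packed_rows[n]    : bytes holding K 2-bit codes for row n of B
    scales[n][b]      : float scale of block b in row n
    zero_points[n][b] : float zero point of block b (may be non-integer)
    """
    out = []
    for n, packed in enumerate(packed_rows):       # outer loop over n
        codes = unpack_2bit(packed, K)
        row = []
        for k in range(K):                         # inner loop over k
            b = k // block_size                    # block that element k belongs to
            row.append((codes[k] - zero_points[n][b]) * scales[n][b])
        out.append(row)
    return out


# One row, K = 8, two blocks of 4: codes 0,1,2,3 followed by 3,2,1,0.
packed = bytes([0b11100100, 0b00011011])
weights = dequantize_naive([packed], [[2.0, 4.0]], [[1.5, 1.5]], K=8, block_size=4)
print(weights[0])  # [-3.0, -1.0, 1.0, 3.0, 6.0, 2.0, -2.0, -6.0]
```

The whole dequantized matrix is materialized before any multiplication happens, which is the structural cost this issue is about.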

Observed impact

End-to-end decode of tencent/Hy-MT1.5-1.8B-2bit exported by mobius PR #302 (32 layers × 7 MatMulNBits nodes per layer, all bits=2 + float zp + block_size=128):

| representation | bytes/elt | ORT path | CPU throughput |
| --- | --- | --- | --- |
| bits=4 packed uint8 zp | 0.5 | MLAS fused path | ~5 tok/s |
| bits=2 + float zp (this issue) | 0.25 | naive scalar fallback | ~0.24 tok/s |

That's ~20× slower than the bits=4 representation for the same dequantized values. Single-layer correctness is fine (max_abs = 2.5e-3 vs HF reference matmul); the bottleneck is purely the dequant kernel.

Hardware / build

Repro

git clone https://github.com/onnxruntime/mobius.git
cd mobius && git checkout add-q1-0-gguf-support && pip install -e .
hf download AngelSlim/Hy-MT1.5-1.8B-2bit-GGUF Hy-MT1.5-1.8B-2bit.gguf --local-dir /tmp/gguf
python -c "
from mobius.integrations.gguf._builder import build_from_gguf
pkg = build_from_gguf('/tmp/gguf/Hy-MT1.5-1.8B-2bit.gguf', keep_quantized=True, dtype='f32')
pkg.save('/tmp/hymt-q2')
"
# Then run greedy decode — ~0.24 tok/s on CPU.

What would help

Either (or both):

  1. An MLAS fast path for bits=2 with float zp (mirroring the existing bits=4 packed-uint8 path: blockwise B dequant fused into the GEMM kernel with SIMD, instead of a full pre-dequant pass).
  2. Extending the existing LUT GEMM kernel (the path that PR #28354, "Add float zero point support for 2-bit LUT GEMM in MatMulNBits", already wired up for the QAD case via MlasInitLutGemmKernelConfig) so it covers general inputs, not just the QAD-specific code path. The LUT decode is essentially what we need; when an input does not meet the LUT-GEMM eligibility criteria, it should fall back to a general-purpose path that still benefits from SIMD rather than to the scalar loop.
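To make the difference from a full pre-dequant pass concrete, here is a rough pure-Python sketch of a fused matrix-vector product that decodes B block by block inside the accumulation loop. It is only an illustration of the control flow, not MLAS code and not the LUT GEMM kernel. The function name, the low-bits-first packing, the per-block 4-entry table, and the requirement that K be a multiple of block_size are all assumptions.

```python
def matvec_fused_2bit(a, packed_rows, scales, zero_points, block_size):
    """Compute y[n] = sum_k a[k] * W[n][k] without materializing W.

    Assumes len(a) is a multiple of block_size and that 2-bit codes are
    packed four per byte, lowest bits first.
    """
    K = len(a)
    y = []
    for n, packed in enumerate(packed_rows):
        acc = 0.0
        for b in range(K // block_size):
            s, zp = scales[n][b], zero_points[n][b]
            # A 2-bit code has only four possible values, so each block
            # needs just four dequantized weights.
            table = [(c - zp) * s for c in range(4)]
            base = b * block_size
            for k in range(base, base + block_size):
                code = (packed[k // 4] >> (2 * (k % 4))) & 0b11
                acc += a[k] * table[code]
        y.append(acc)
    return y


a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
packed = bytes([0b11100100, 0b00011011])   # codes 0,1,2,3 then 3,2,1,0
y = matvec_fused_2bit(a, [packed], [[2.0, 4.0]], [[1.5, 1.5]], block_size=4)
print(y)  # [-10.0]
```

A real kernel would vectorize the inner loop and tile over n and k; the point here is only that the float zero point enters once per block, when the four-entry table is built.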

Why mobius cares right now

Until this lands, mobius PR #302 ships the Tencent Q1_0 → ONNX path with a tencent_q1_0_use_native_2bit flag that defaults to false (i.e., inflate to bits=4 for usable throughput). Setting the flag to true gives the smaller, more semantically faithful representation but is currently unusable for interactive decode. Once a fast bits=2 + float-zp kernel exists we'd flip the default.
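The issue does not spell out how the inflation to bits=4 is done. One way such a re-encoding could work, sketched below under that caveat, is to double the quantization grid so that a half-integer float zero point becomes an integer: `(q - zp) * s == (2q - 2zp) * (s / 2)`. The function name is hypothetical and this is not taken from the mobius source.

```python
def inflate_block_to_4bit(codes_2bit, scale, zero_point):
    """Re-encode one 2-bit block (float zero point) as 4-bit codes (integer zero point).

    Hypothetical helper: doubles codes and zero point, halves the scale.
    Only valid when zero_point is a multiple of 0.5 and 2 * zero_point fits in 4 bits.
    """
    zp4 = 2 * zero_point
    if zp4 != int(zp4) or not 0 <= zp4 <= 15:
        raise ValueError("zero point must be a multiple of 0.5 in [0, 7.5]")
    codes_4bit = [2 * q for q in codes_2bit]   # 0..3 maps to 0, 2, 4, 6
    return codes_4bit, scale / 2, int(zp4)


codes2, scale2, zp2 = [0, 1, 2, 3], 0.02, 1.5
codes4, scale4, zp4 = inflate_block_to_4bit(codes2, scale2, zp2)

before = [(q - zp2) * scale2 for q in codes2]
after = [(q - zp4) * scale4 for q in codes4]
print(codes4, zp4)
print(before)
print(after)
```

Under this scheme the dequantized values are unchanged while storage doubles from 0.25 to 0.5 bytes per element, which matches the trade-off described above.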
