[Qwen3] Allow packed QKV MatMul under QK-Norm via post-MatMul Split #2137
Open
xiaofeihan1 wants to merge 2 commits into
Previously `use_packed_matmul` was disabled whenever q_norm or k_norm was set, so Qwen3 (and any other QK-Norm architecture) emitted three separate q_proj / k_proj / v_proj MatMulNBits nodes per layer.

Allow packed MatMul in this case and insert a single ONNX Split node right after the packed `qkv_proj/MatMul` to recover the Q/K/V tensors that feed the existing q_norm/k_norm path. This keeps the per-head SimplifiedLayerNormalization semantics unchanged (mathematically equivalent, quantization unchanged) while reducing 3 MatMulNBits per layer to 1.

A single `Split` is preferred over 3 `Slice` nodes because Split reads the packed output once and writes 3 outputs in a single dispatch, which avoids re-reading the same packed tensor 3x on each decode step.

The packed-bias branch is also gated off when QK-Norm forces unpacking, to avoid a shape mismatch from a packed Add over a sliced Q tensor.

Verified on Qwen3-1.7B (int4, accuracy_level=4): generated text is byte-identical to the unpacked baseline, and on RTX 5080 (WebGPU EP) gen TPS improves +5.5% (121.6 -> 128.3) with no prefill regression.
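The equivalence claim can be checked numerically. What follows is a minimal NumPy sketch, not the builder's code: the dimensions and weight names are illustrative assumptions, plain float matrices stand in for the quantized MatMulNBits weights, and `np.split` stands in for the ONNX `Split` node.

```python
import numpy as np

# Illustrative sizes only (assumptions, not taken from the PR).
hidden, num_heads, num_kv_heads, head_dim = 64, 4, 2, 16
q_size = num_heads * head_dim        # width of the Q projection
kv_size = num_kv_heads * head_dim    # width of each of the K and V projections

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, hidden))      # (batch, seq, hidden)
w_q = rng.standard_normal((hidden, q_size))
w_k = rng.standard_normal((hidden, kv_size))
w_v = rng.standard_normal((hidden, kv_size))

# Unpacked path: three independent projections.
q_ref, k_ref, v_ref = x @ w_q, x @ w_k, x @ w_v

# Packed path: one projection over the concatenated weight, then one split
# along the last axis at the Q/K and K/V boundaries.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)
qkv = x @ w_qkv
q, k, v = np.split(qkv, [q_size, q_size + kv_size], axis=-1)

# True when every split output matches its unpacked counterpart.
same = all(np.allclose(a, b) for a, b in ((q, q_ref), (k, k_ref), (v, v_ref)))
```

Because Q and K come out with the same shapes as in the unpacked path, whatever per-head normalization follows can consume them unchanged.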
Pull request overview
This PR updates the Python model builder’s attention graph construction to allow packed QKV projection even for QK-Norm architectures (e.g., Qwen3), by inserting an ONNX Split immediately after the packed qkv_proj MatMul so downstream Q/K/V-specific paths (including per-head Q/K norm) remain unchanged.
Changes:
- Stop disabling `use_packed_matmul` purely due to `q_norm`/`k_norm` being enabled.
- Add a `make_split(...)` helper and use it to split packed QKV output into Q, K, V tensors when QK-Norm is active.
- Disable the packed-bias Add fusion when QK-Norm is active (since the graph now operates on split Q/K/V tensors).
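The signature of `make_split(...)` is not shown here, so the sketch below does not attempt to reproduce it. It only illustrates the split sizes such a node would need, assuming the packed output is laid out as [Q | K | V] along the last axis and that K and V use fewer heads than Q (grouped-query attention). The function name `qkv_split_sizes` and the head counts are hypothetical.

```python
def qkv_split_sizes(num_attn_heads: int, num_kv_heads: int, head_dim: int) -> list[int]:
    """Widths of the Q, K and V slices along the last axis of the packed output."""
    q_size = num_attn_heads * head_dim
    kv_size = num_kv_heads * head_dim
    return [q_size, kv_size, kv_size]


# Illustrative head counts (assumptions, not taken from the PR).
sizes = qkv_split_sizes(num_attn_heads=32, num_kv_heads=8, head_dim=128)

# The three widths must add up to the packed projection's output width.
packed_width = (32 + 2 * 8) * 128
```

If the sizes did not sum to the packed width, the split would be rejected or would misalign K and V, so this sum is a cheap sanity check when building the graph.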
Previously `use_packed_matmul` was disabled whenever q_norm or k_norm was set, so Qwen3 (and any other QK-Norm architecture) emitted three separate q_proj / k_proj / v_proj MatMulNBits nodes per layer.

This PR allows packed QKV MatMul for QK-Norm models by inserting a single ONNX `Split` after the packed projection so the existing per-head Q/K SimplifiedLayerNormalization path is unchanged. Math is equivalent; quantization is unchanged.

Graph change (per attention layer)
Before:
After (Split placed after the optional packed bias-Add so packed bias fusion is preserved):
Net per layer: −2 MatMulNBits, +1 Split. For Qwen3-4B (36 layers): −72 MatMulNBits, +36 Split, total ONNX nodes 700 → 665.
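The per-layer delta scales linearly with the layer count. A small sketch of that arithmetic follows; the helper name `packed_qkv_node_delta` is hypothetical, and it counts only the two node types named above.

```python
def packed_qkv_node_delta(num_layers: int) -> dict[str, int]:
    """Change in node counts when each layer swaps 3 MatMulNBits for 1 plus a Split."""
    return {"MatMulNBits": -2 * num_layers, "Split": 1 * num_layers}


# 36 layers, as stated for Qwen3-4B above.
delta = packed_qkv_node_delta(36)
```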
Why Split (not Slice)
Single read of the packed tensor, single dispatch with 3 outputs — avoids re-reading the same packed Q/K/V tensor three times per decode step.
Verification — Qwen3-4B int4 (RelWithDebInfo build, accuracy_level=4)
Compared two builds of the same source model:
664a61b1, QK-Norm forces 3 independent MatMuls)

Output token sequences are byte-identical on both GPUs (no accuracy drift).
WebGPU GPU profile (aggregate: 1 prefill + 50 decode × 2 phases)
End-to-end perf — Qwen3-4B prefill-1000, 5 runs
(NV result matches the −2.8% GPU profile measurement; iGPU receives a similar relative gain. Prompt TPS and TTFT are unchanged on both vendors.)
Compatibility
q_norm && k_norm).
use_matmul_in_attn): unchanged.