
Commit ba161de

Rewrite FP4 matmul as NVFP4 FP4xFP4 GEMM
Restructures the challenge so a submission directly verifies AutoKernel's FP4 matmul claim (Table 5 of the paper): both operands are packed FP4 E2M1 with E4M3 per-block scales and a per-tensor FP32 alpha, matching the NVFP4 layout used by CUTLASS and qutlass. The previous revision was W4A16 weight-only quant, which cannot reach the TF/s regime the paper reports because x was still FP16.

Key changes:
- Both x and w are packed FP4 uint8 (nibbles); block size = 16.
- Scales are raw E4M3 bytes (torch.float8_e4m3fn bit patterns).
- Reference is a pure FP32 dequant + matmul oracle.
- Performance shape (M=2048, N=18432, K=3072) taken verbatim from the Triton vs CUTLASS row in Table 5 so TF/s is directly comparable.
- Tolerances loosened to atol=0.1, rtol=0.05 to admit FP16 accumulation used by tensor-core paths.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e099a71 commit ba161de
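For reference, the two element formats the commit message names (FP4 E2M1 values and E4M3 block-scale bytes) can be decoded in a few lines of Python. This is an illustrative sketch, not code from the repo; the helper names simply mirror the `fp4_decode` / `e4m3_decode` pseudocode in the diff.

```python
# Illustrative decoders for the NVFP4 element formats (not repo code).

# FP4 E2M1 nibble -> value: [sign | exp (2 bits) | mantissa (1 bit)]
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def fp4_decode(nibble: int) -> float:
    return FP4_E2M1[nibble & 0xF]

def e4m3_decode(byte: int) -> float:
    """Decode a float8_e4m3fn bit pattern: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0:                      # subnormal: no implicit leading 1
        return sign * (mant / 8.0) * 2.0 ** -6
    if exp == 0xF and mant == 0x7:    # e4m3fn has NaN but no infinities
        return float("nan")
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

# Sanity checks: nibble 0x2 is 1.0, 0xA is -1.0; byte 0x38 is the E4M3 pattern for 1.0.
assert fp4_decode(0x2) == 1.0 and fp4_decode(0xA) == -1.0
assert e4m3_decode(0x38) == 1.0
```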

8 files changed

Lines changed: 187 additions & 146 deletions

Lines changed: 58 additions & 53 deletions
@@ -1,15 +1,17 @@
 <p>
-Implement an FP4 weight-only quantized matrix multiplication, the kernel at the heart of
-modern low-precision LLM inference on Hopper and Blackwell GPUs. Given a float16 activation
-matrix <code>x</code> of shape <code>M &times; K</code> and a weight matrix stored in packed
-FP4 E2M1 format, compute <code>y = x &times; W<sup>T</sup></code> of shape
-<code>M &times; N</code>, where <code>W</code> is the dequantized float16 weight matrix of
-shape <code>N &times; K</code>.
+Implement an <strong>NVFP4</strong> matrix multiplication, the low-precision GEMM that powers
+state-of-the-art LLM inference on Hopper and Blackwell GPUs. Both operands are stored in 4-bit
+floating point (FP4 E2M1) with per-block FP8 (E4M3) scales along the reduction dimension, plus
+a single per-tensor FP32 scale. Given packed activations <code>x_q</code> of shape
+<code>M &times; K</code>, packed weights <code>w_q</code> of shape <code>N &times; K</code>,
+and their respective block scales, compute
+<code>y = alpha &times; (x &times; w<sup>T</sup>)</code> of shape <code>M &times; N</code>
+in float16.
 </p>
 
 <p>
-<strong>FP4 E2M1 format:</strong> Each weight is encoded in 4 bits as
-[sign | exponent (2 bits) | mantissa (1 bit)], representing one of sixteen values:
+<strong>FP4 E2M1 encoding:</strong> Each weight is 4 bits
+[sign | exp (2 bits) | mantissa (1 bit)] representing one of sixteen values:
 <code>{&plusmn;0, &plusmn;0.5, &plusmn;1, &plusmn;1.5, &plusmn;2, &plusmn;3, &plusmn;4, &plusmn;6}</code>.
 The nibble-to-value mapping is:
 </p>
@@ -25,89 +27,92 @@
 </pre>
 
 <p>
-<strong>Packing:</strong> Each byte of <code>w_q</code> stores two FP4 weights. The high
-nibble (bits 7&ndash;4) holds <code>w[n, 2i]</code> and the low nibble (bits 3&ndash;0) holds
-<code>w[n, 2i+1]</code>.
+<strong>Packing:</strong> Each byte of <code>x_q</code> / <code>w_q</code> stores two FP4
+values. The high nibble (bits 7&ndash;4) holds the even-index value and the low nibble
+(bits 3&ndash;0) holds the odd-index value.
 </p>
 
 <p>
-<strong>Dequantization:</strong> Weights are dequantized group-wise. Each contiguous block of
-<code>group_size</code> weights along the <code>K</code> dimension shares one float16 scale:
+<strong>Block scales:</strong> Each contiguous block of <strong>16</strong> FP4 values along
+the <code>K</code> dimension shares one E4M3 (float8) scale. The scale tensors
+<code>x_scales</code> and <code>w_scales</code> are passed as raw uint8 bytes holding the
+E4M3 bit patterns. Dequantization is:
 </p>
 <pre>
-W[n, k] = fp4_decode(w_q_nibble[n, k]) * scales[n, k // group_size]
+x[m, k] = fp4_decode(x_q_nibble[m, k]) * e4m3_decode(x_scales[m, k // 16])
+w[n, k] = fp4_decode(w_q_nibble[n, k]) * e4m3_decode(w_scales[n, k // 16])
+y[m, n] = alpha * sum_k x[m, k] * w[n, k]
 </pre>
 
 <h2>Implementation Requirements</h2>
 <ul>
 <li>Use only native features (external libraries are not permitted)</li>
 <li>The <code>solve</code> function signature must remain unchanged</li>
-<li>The final result must be stored in <code>y</code></li>
+<li>The final result must be stored in <code>y</code> as float16</li>
 </ul>
 
 <h2>Example</h2>
 <p>
-Input (<code>M</code> = 2, <code>N</code> = 4, <code>K</code> = 4, <code>group_size</code> = 2):
+Input (<code>M</code> = 2, <code>N</code> = 2, <code>K</code> = 16, <code>alpha</code> = 1.0):
 </p>
 <p>
-Activations \(x\) (float16, \(2 \times 4\)):
+Packed activations \(x\_q\) (uint8, \(2 \times 8\)) and decoded FP4 values (each row has
+sixteen values):
 \[
+x\_q =
 \begin{bmatrix}
-1.0 & 0.0 & 1.0 & 0.0 \\
-0.0 & 1.0 & 0.0 & 1.0
-\end{bmatrix}
-\]
-Packed weights \(w\_q\) (uint8, \(4 \times 2\)) decoded via the FP4 E2M1 table:
-\[
-\begin{bmatrix}
-\texttt{0x22} & \texttt{0x22} \\
-\texttt{0x44} & \texttt{0x44} \\
-\texttt{0xAA} & \texttt{0xAA} \\
-\texttt{0x00} & \texttt{0x00}
+\texttt{0x22} & \cdots & \texttt{0x22} \\
+\texttt{0x11} & \cdots & \texttt{0x11}
 \end{bmatrix}
 \;\Rightarrow\;
-W_{\text{fp4}} =
+x_{\text{fp4}} =
 \begin{bmatrix}
-1.0 & 1.0 & 1.0 & 1.0 \\
-2.0 & 2.0 & 2.0 & 2.0 \\
--1.0 & -1.0 & -1.0 & -1.0 \\
-0.0 & 0.0 & 0.0 & 0.0
+1.0 & 1.0 & \cdots & 1.0 \\
+0.5 & 0.5 & \cdots & 0.5
 \end{bmatrix}
 \]
-Scales (float16, \(4 \times 2\), all entries 0.5):
+Packed weights \(w\_q\) (uint8, \(2 \times 8\)):
 \[
+w\_q =
 \begin{bmatrix}
-0.5 & 0.5 \\
-0.5 & 0.5 \\
-0.5 & 0.5 \\
-0.5 & 0.5
+\texttt{0x44} & \cdots & \texttt{0x44} \\
+\texttt{0xAA} & \cdots & \texttt{0xAA}
 \end{bmatrix}
 \;\Rightarrow\;
-W_{\text{dequant}} =
+w_{\text{fp4}} =
 \begin{bmatrix}
-0.5 & 0.5 & 0.5 & 0.5 \\
-1.0 & 1.0 & 1.0 & 1.0 \\
--0.5 & -0.5 & -0.5 & -0.5 \\
-0.0 & 0.0 & 0.0 & 0.0
+2.0 & 2.0 & \cdots & 2.0 \\
+-1.0 & -1.0 & \cdots & -1.0
 \end{bmatrix}
 \]
-Output \(y = x \times W^T\) (float16, \(2 \times 4\)):
+Block scales (one block per row since <code>K</code> = 16): both
+<code>x_scales</code> and <code>w_scales</code> are uint8 \(2 \times 1\) with every byte
+equal to <code>0x38</code>, which is the E4M3 bit pattern for 1.0. The dequantized operands
+therefore equal the FP4 values above.
+</p>
+<p>
+Output \(y = \alpha \cdot (x \times w^T)\) (float16, \(2 \times 2\)):
 \[
 \begin{bmatrix}
-1.0 & 2.0 & -1.0 & 0.0 \\
-1.0 & 2.0 & -1.0 & 0.0
+\sum 1.0 \cdot 2.0 & \sum 1.0 \cdot (-1.0) \\
+\sum 0.5 \cdot 2.0 & \sum 0.5 \cdot (-1.0)
+\end{bmatrix}
+=
+\begin{bmatrix}
+32.0 & -16.0 \\
+16.0 & -8.0
 \end{bmatrix}
 \]
 </p>
 
 <h2>Constraints</h2>
 <ul>
-<li>1 &le; <code>M</code>, <code>N</code> &le; 8,192</li>
-<li>1 &le; <code>K</code> &le; 8,192</li>
-<li><code>K</code> is divisible by <code>2</code> and by <code>group_size</code></li>
-<li><code>group_size</code> &isin; {2, 4, 8, 16, 32}</li>
+<li>1 &le; <code>M</code>, <code>N</code> &le; 32,768</li>
+<li>16 &le; <code>K</code> &le; 32,768</li>
+<li><code>K</code> is divisible by <strong>16</strong> (the NVFP4 block size)</li>
 <li>All tensors are stored in row-major order</li>
-<li>Input dtype: <code>x</code> and <code>scales</code> are float16; <code>w_q</code> is uint8</li>
-<li>Output dtype: <code>y</code> is float16</li>
-<li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 8,192, <code>K</code> = 3,072, <code>group_size</code> = 32</li>
+<li>Inputs: <code>x_q</code>, <code>w_q</code>, <code>x_scales</code>, <code>w_scales</code>
+are <code>uint8</code>; <code>alpha</code> is <code>float32</code></li>
+<li>Output: <code>y</code> is <code>float16</code></li>
+<li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 18,432, <code>K</code> = 3,072</li>
 </ul>

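As a sanity check, the worked example in the new description can be reproduced with a small pure-Python version of the FP32 dequant + matmul oracle the commit message describes. This is a sketch under the stated layout assumptions (the repo's actual reference presumably uses torch); `nvfp4_matmul_ref` and `dequant` are illustrative names.

```python
# Illustrative NVFP4 dequant + matmul oracle (not repo code).

FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def e4m3_decode(b):
    """Decode a float8_e4m3fn byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if b & 0x80 else 1.0
    exp, mant = (b >> 3) & 0xF, b & 0x7
    if exp == 0:  # subnormal
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

def dequant(q_rows, scale_rows):
    """Unpack nibbles (high = even index, low = odd) and apply per-16 block scales."""
    out = []
    for qr, sr in zip(q_rows, scale_rows):
        vals = []
        for byte in qr:
            vals.append(FP4_E2M1[byte >> 4])   # element 2i
            vals.append(FP4_E2M1[byte & 0xF])  # element 2i+1
        out.append([v * e4m3_decode(sr[k // 16]) for k, v in enumerate(vals)])
    return out

def nvfp4_matmul_ref(x_q, x_scales, w_q, w_scales, alpha=1.0):
    """y = alpha * (dequant(x) @ dequant(w).T), accumulated in Python floats (FP64)."""
    x, w = dequant(x_q, x_scales), dequant(w_q, w_scales)
    return [[alpha * sum(xi * wi for xi, wi in zip(xr, wr)) for wr in w] for xr in x]

# Worked example from the diff: M = N = 2, K = 16, alpha = 1.0.
x_q = [[0x22] * 8, [0x11] * 8]     # rows decode to all 1.0 and all 0.5
w_q = [[0x44] * 8, [0xAA] * 8]     # rows decode to all 2.0 and all -1.0
scales = [[0x38], [0x38]]          # 0x38 is the E4M3 bit pattern for 1.0
y = nvfp4_matmul_ref(x_q, scales, w_q, scales)
# -> [[32.0, -16.0], [16.0, -8.0]]
```

Running it reproduces the output matrix in the example, [[32.0, -16.0], [16.0, -8.0]], which confirms the 0x22/0x11/0x44/0xAA nibble decodes and the unit block scales.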