|
1 | 1 | <p> |
2 | | - Implement an FP4 weight-only quantized matrix multiplication, the kernel at the heart of |
3 | | - modern low-precision LLM inference on Hopper and Blackwell GPUs. Given a float16 activation |
4 | | - matrix <code>x</code> of shape <code>M × K</code> and a weight matrix stored in packed |
5 | | - FP4 E2M1 format, compute <code>y = x × W<sup>T</sup></code> of shape |
6 | | - <code>M × N</code>, where <code>W</code> is the dequantized float16 weight matrix of |
7 | | - shape <code>N × K</code>. |
| 2 | + Implement an <strong>NVFP4</strong> matrix multiplication, the low-precision GEMM that powers |
| 3 | + state-of-the-art LLM inference on Hopper and Blackwell GPUs. Both operands are stored in 4-bit |
| 4 | + floating point (FP4 E2M1) with per-block FP8 (E4M3) scales along the reduction dimension, plus |
| 5 | + a single per-tensor FP32 scale <code>alpha</code>. Given packed activations <code>x_q</code>
| 6 | + holding an <code>M × K</code> FP4 matrix and packed weights <code>w_q</code> holding an
| 7 | + <code>N × K</code> FP4 matrix (two FP4 values per byte), together with their respective block scales, compute
| 8 | + <code>y = alpha × (x × w<sup>T</sup>)</code> of shape <code>M × N</code> |
| 9 | + in float16. |
8 | 10 | </p> |
9 | 11 |
|
10 | 12 | <p> |
11 | | - <strong>FP4 E2M1 format:</strong> Each weight is encoded in 4 bits as |
12 | | - [sign | exponent (2 bits) | mantissa (1 bit)], representing one of sixteen values: |
| 13 | + <strong>FP4 E2M1 encoding:</strong> Each activation and weight value is 4 bits
| 14 | + [sign (1 bit) | exponent (2 bits) | mantissa (1 bit)], representing one of sixteen values:
13 | 15 | <code>{±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}</code>. |
14 | 16 | The nibble-to-value mapping is: |
15 | 17 | </p> |
|
25 | 27 | </pre> |
26 | 28 |
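<p>
  One way to realize the nibble-to-value mapping in a kernel is a 16-entry lookup table. The
  CUDA sketch below follows directly from the E2M1 definition; the identifier names are
  illustrative and not part of the problem interface:
</p>
<pre>
// E2M1 nibble -> value table, indexed by the 4-bit code.
// Codes 0x0-0x7 are the non-negative values, 0x8-0xF their negated counterparts.
__constant__ float FP4_E2M1_LUT[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f };

__device__ __forceinline__ float fp4_e2m1_decode(unsigned nibble) {
    return FP4_E2M1_LUT[nibble & 0xF];
}
</pre>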
|
27 | 29 | <p> |
28 | | - <strong>Packing:</strong> Each byte of <code>w_q</code> stores two FP4 weights. The high |
29 | | - nibble (bits 7–4) holds <code>w[n, 2i]</code> and the low nibble (bits 3–0) holds |
30 | | - <code>w[n, 2i+1]</code>. |
| 30 | + <strong>Packing:</strong> Each byte of <code>x_q</code> / <code>w_q</code> stores two FP4 |
| 31 | + values. The high nibble (bits 7–4) holds the even-index value and the low nibble |
| 32 | + (bits 3–0) holds the odd-index value. |
31 | 33 | </p> |
32 | 34 |
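<p>
  Unpacking therefore needs only shifts and masks. A small CUDA sketch of the convention (the
  helper name is illustrative):
</p>
<pre>
// One packed byte holds the FP4 codes of columns 2i (high nibble) and 2i+1 (low nibble).
__device__ __forceinline__ void fp4_unpack_byte(unsigned char byte,
                                                unsigned* even_code, unsigned* odd_code) {
    *even_code = byte >> 4;    // code of x[m, 2i]   (or w[n, 2i])
    *odd_code  = byte & 0x0F;  // code of x[m, 2i+1] (or w[n, 2i+1])
}
</pre>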
|
33 | 35 | <p> |
34 | | - <strong>Dequantization:</strong> Weights are dequantized group-wise. Each contiguous block of |
35 | | - <code>group_size</code> weights along the <code>K</code> dimension shares one float16 scale: |
| 36 | + <strong>Block scales:</strong> Each contiguous block of <strong>16</strong> FP4 values along |
| 37 | + the <code>K</code> dimension shares one E4M3 (float8) scale. The scale tensors
| 38 | + <code>x_scales</code> (shape <code>M × K/16</code>) and <code>w_scales</code> (shape <code>N × K/16</code>)
| 39 | + are passed as raw uint8 bytes holding the E4M3 bit patterns. Dequantization is:
36 | 40 | </p> |
37 | 41 | <pre> |
38 | | -W[n, k] = fp4_decode(w_q_nibble[n, k]) * scales[n, k // group_size] |
| 42 | +x[m, k] = fp4_decode(x_q_nibble[m, k]) * e4m3_decode(x_scales[m, k // 16]) |
| 43 | +w[n, k] = fp4_decode(w_q_nibble[n, k]) * e4m3_decode(w_scales[n, k // 16]) |
| 44 | +y[m, n] = alpha * sum_k x[m, k] * w[n, k] |
39 | 45 | </pre> |
40 | 46 |
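<p>
  Putting the format together, the sketch below is a deliberately naive CUDA reference kernel:
  one thread per output element, no shared memory, no tensor cores. It only pins down the
  indexing and numerics; kernel and parameter names are assumptions, and a competitive
  submission would instead use tiling or the hardware FP4 MMA path.
</p>
<pre>
#include <cuda_fp16.h>
#include <stdint.h>

// E2M1 nibble -> value (same table as above, inlined so this sketch is self-contained).
__device__ __forceinline__ float decode_fp4(uint8_t code) {
    const float lut[16] = {  0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                            -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f };
    return lut[code & 0xF];
}

// FP8 E4M3: [sign | 4-bit exponent, bias 7 | 3-bit mantissa]. NaN encodings are ignored
// here on the assumption that scale factors are finite.
__device__ __forceinline__ float decode_e4m3(uint8_t b) {
    int s = b >> 7, e = (b >> 3) & 0xF, m = b & 0x7;
    float mag = (e == 0) ? ldexpf(m / 8.0f, -6)            // subnormal
                         : ldexpf(1.0f + m / 8.0f, e - 7); // normal
    return s ? -mag : mag;
}

// One thread per y[m, n]; each 16-wide block contributes scale_x * scale_w * (FP4 dot product).
__global__ void nvfp4_gemm_naive(const uint8_t* x_q, const uint8_t* w_q,
                                 const uint8_t* x_scales, const uint8_t* w_scales,
                                 float alpha, __half* y, int M, int N, int K) {
    int n = blockIdx.x * blockDim.x + threadIdx.x;
    int m = blockIdx.y * blockDim.y + threadIdx.y;
    if (m >= M || n >= N) return;

    int blocks = K / 16;                            // scale blocks per row
    float acc = 0.0f;
    for (int kb = 0; kb < blocks; ++kb) {
        float sx = decode_e4m3(x_scales[m * blocks + kb]);
        float sw = decode_e4m3(w_scales[n * blocks + kb]);
        float partial = 0.0f;
        for (int i = 0; i < 8; ++i) {               // 8 packed bytes = 16 FP4 values
            uint8_t bx = x_q[m * (K / 2) + kb * 8 + i];
            uint8_t bw = w_q[n * (K / 2) + kb * 8 + i];
            // high nibble = even column, low nibble = odd column
            partial += decode_fp4(bx >> 4) * decode_fp4(bw >> 4)
                     + decode_fp4(bx & 0xF) * decode_fp4(bw & 0xF);
        }
        acc += sx * sw * partial;                   // apply both block scales once per block
    }
    y[m * N + n] = __float2half(alpha * acc);       // per-tensor scale, store as float16
}
</pre>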
|
41 | 47 | <h2>Implementation Requirements</h2> |
42 | 48 | <ul> |
43 | 49 | <li>Use only native features (external libraries are not permitted)</li> |
44 | 50 | <li>The <code>solve</code> function signature must remain unchanged</li> |
45 | | - <li>The final result must be stored in <code>y</code></li> |
| 51 | + <li>The final result must be stored in <code>y</code> as float16</li> |
46 | 52 | </ul> |
47 | 53 |
|
48 | 54 | <h2>Example</h2> |
49 | 55 | <p> |
50 | | - Input (<code>M</code> = 2, <code>N</code> = 4, <code>K</code> = 4, <code>group_size</code> = 2): |
| 56 | + Input (<code>M</code> = 2, <code>N</code> = 2, <code>K</code> = 16, <code>alpha</code> = 1.0): |
51 | 57 | </p> |
52 | 58 | <p> |
53 | | - Activations \(x\) (float16, \(2 \times 4\)): |
| 59 | + Packed activations \(x\_q\) (uint8, \(2 \times 8\)) and their decoded FP4 values (sixteen
| 60 | + values per row):
54 | 61 | \[ |
| 62 | + x\_q = |
55 | 63 | \begin{bmatrix} |
56 | | - 1.0 & 0.0 & 1.0 & 0.0 \\ |
57 | | - 0.0 & 1.0 & 0.0 & 1.0 |
58 | | - \end{bmatrix} |
59 | | - \] |
60 | | - Packed weights \(w\_q\) (uint8, \(4 \times 2\)) decoded via the FP4 E2M1 table: |
61 | | - \[ |
62 | | - \begin{bmatrix} |
63 | | - \texttt{0x22} & \texttt{0x22} \\ |
64 | | - \texttt{0x44} & \texttt{0x44} \\ |
65 | | - \texttt{0xAA} & \texttt{0xAA} \\ |
66 | | - \texttt{0x00} & \texttt{0x00} |
| 64 | + \texttt{0x22} & \cdots & \texttt{0x22} \\ |
| 65 | + \texttt{0x11} & \cdots & \texttt{0x11} |
67 | 66 | \end{bmatrix} |
68 | 67 | \;\Rightarrow\; |
69 | | - W_{\text{fp4}} = |
| 68 | + x_{\text{fp4}} = |
70 | 69 | \begin{bmatrix} |
71 | | - 1.0 & 1.0 & 1.0 & 1.0 \\ |
72 | | - 2.0 & 2.0 & 2.0 & 2.0 \\ |
73 | | - -1.0 & -1.0 & -1.0 & -1.0 \\ |
74 | | - 0.0 & 0.0 & 0.0 & 0.0 |
| 70 | + 1.0 & 1.0 & \cdots & 1.0 \\ |
| 71 | + 0.5 & 0.5 & \cdots & 0.5 |
75 | 72 | \end{bmatrix} |
76 | 73 | \] |
77 | | - Scales (float16, \(4 \times 2\), all entries 0.5): |
| 74 | + Packed weights \(w\_q\) (uint8, \(2 \times 8\)): |
78 | 75 | \[ |
| 76 | + w\_q = |
79 | 77 | \begin{bmatrix} |
80 | | - 0.5 & 0.5 \\ |
81 | | - 0.5 & 0.5 \\ |
82 | | - 0.5 & 0.5 \\ |
83 | | - 0.5 & 0.5 |
| 78 | + \texttt{0x44} & \cdots & \texttt{0x44} \\ |
| 79 | + \texttt{0xAA} & \cdots & \texttt{0xAA} |
84 | 80 | \end{bmatrix} |
85 | 81 | \;\Rightarrow\; |
86 | | - W_{\text{dequant}} = |
| 82 | + w_{\text{fp4}} = |
87 | 83 | \begin{bmatrix} |
88 | | - 0.5 & 0.5 & 0.5 & 0.5 \\ |
89 | | - 1.0 & 1.0 & 1.0 & 1.0 \\ |
90 | | - -0.5 & -0.5 & -0.5 & -0.5 \\ |
91 | | - 0.0 & 0.0 & 0.0 & 0.0 |
| 84 | + 2.0 & 2.0 & \cdots & 2.0 \\ |
| 85 | + -1.0 & -1.0 & \cdots & -1.0 |
92 | 86 | \end{bmatrix} |
93 | 87 | \] |
94 | | - Output \(y = x \times W^T\) (float16, \(2 \times 4\)): |
| 88 | + Block scales (one block per row since <code>K</code> = 16): both |
| 89 | + <code>x_scales</code> and <code>w_scales</code> are uint8 \(2 \times 1\) with every byte |
| 90 | + equal to <code>0x38</code>, which is the E4M3 bit pattern for 1.0. The dequantized operands |
| 91 | + therefore equal the FP4 values above. |
| 92 | +</p> |
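<p>
  For reference, that scale byte decodes as follows under the E4M3 layout
  [1 sign bit | 4 exponent bits, bias 7 | 3 mantissa bits]:
  \[
  \texttt{0x38} = 0\,0111\,000_2 \;\Rightarrow\; (+1) \times 2^{\,7-7} \times \left(1 + \tfrac{0}{8}\right) = 1.0 .
  \]
</p>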
| 93 | +<p> |
| 94 | + Output \(y = \alpha \cdot (x \times w^T)\) (float16, \(2 \times 2\)): |
95 | 95 | \[ |
96 | 96 | \begin{bmatrix} |
97 | | - 1.0 & 2.0 & -1.0 & 0.0 \\ |
98 | | - 1.0 & 2.0 & -1.0 & 0.0 |
| 97 | + \sum_{k=1}^{16} 1.0 \cdot 2.0 & \sum_{k=1}^{16} 1.0 \cdot (-1.0) \\
| 98 | + \sum_{k=1}^{16} 0.5 \cdot 2.0 & \sum_{k=1}^{16} 0.5 \cdot (-1.0)
| 99 | + \end{bmatrix} |
| 100 | + = |
| 101 | + \begin{bmatrix} |
| 102 | + 32.0 & -16.0 \\ |
| 103 | + 16.0 & -8.0 |
99 | 104 | \end{bmatrix} |
100 | 105 | \] |
101 | 106 | </p> |
102 | 107 |
|
103 | 108 | <h2>Constraints</h2> |
104 | 109 | <ul> |
105 | | - <li>1 ≤ <code>M</code>, <code>N</code> ≤ 8,192</li> |
106 | | - <li>1 ≤ <code>K</code> ≤ 8,192</li> |
107 | | - <li><code>K</code> is divisible by <code>2</code> and by <code>group_size</code></li> |
108 | | - <li><code>group_size</code> ∈ {2, 4, 8, 16, 32}</li> |
| 110 | + <li>1 ≤ <code>M</code>, <code>N</code> ≤ 32,768</li> |
| 111 | + <li>16 ≤ <code>K</code> ≤ 32,768</li> |
| 112 | + <li><code>K</code> is divisible by <strong>16</strong> (the NVFP4 block size)</li> |
109 | 113 | <li>All tensors are stored in row-major order</li> |
110 | | - <li>Input dtype: <code>x</code> and <code>scales</code> are float16; <code>w_q</code> is uint8</li> |
111 | | - <li>Output dtype: <code>y</code> is float16</li> |
112 | | - <li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 8,192, <code>K</code> = 3,072, <code>group_size</code> = 32</li> |
| 114 | + <li>Inputs: <code>x_q</code>, <code>w_q</code>, <code>x_scales</code>, <code>w_scales</code> |
| 115 | + are <code>uint8</code>; <code>alpha</code> is <code>float32</code></li> |
| 116 | + <li>Output: <code>y</code> is <code>float16</code></li> |
| 117 | + <li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 18,432, <code>K</code> = 3,072</li> |
113 | 118 | </ul> |
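<p>
  The exact <code>solve</code> entry point is defined by the judging harness and is not
  reproduced here. Purely as an illustration, and with every parameter name below an assumption,
  a naive launch of the reference kernel sketched earlier could look like:
</p>
<pre>
// Hypothetical host-side wrapper around nvfp4_gemm_naive from the sketch above;
// one thread per element of y, launched over an M x N grid of 16x16 blocks.
extern "C" void solve(const uint8_t* x_q, const uint8_t* w_q,
                      const uint8_t* x_scales, const uint8_t* w_scales,
                      float alpha, __half* y, int M, int N, int K) {
    dim3 block(16, 16);
    dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
    nvfp4_gemm_naive<<<grid, block>>>(x_q, w_q, x_scales, w_scales, alpha, y, M, N, K);
}
</pre>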