|
| 1 | +<p> |
| 2 | + Implement a weight-only INT4 quantized matrix multiplication (W4A16), a core kernel used in |
| 3 | + modern LLM inference. Given a float16 activation matrix <code>x</code> of shape |
| 4 | + <code>M × K</code> and a weight matrix stored in packed INT4 format, compute the output |
| 5 | + matrix <code>y = x × W<sup>T</sup></code> of shape <code>M × N</code>, where |
| 6 | + <code>W</code> is the dequantized float16 weight matrix of shape <code>N × K</code>. |
| 7 | +</p> |
| 8 | + |
| 9 | +<svg width="700" height="400" viewBox="0 0 700 400" xmlns="http://www.w3.org/2000/svg" |
| 10 | + style="display:block; margin:20px auto; font-family:monospace;"> |
| 11 | + <rect width="700" height="400" fill="#222" rx="10"/> |
| 12 | + <defs> |
| 13 | + <marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto"> |
| 14 | + <path d="M0,0 L0,6 L8,3 z" fill="#aaa"/> |
| 15 | + </marker> |
| 16 | + </defs> |
| 17 | + |
| 18 | + <!-- ============================================================ --> |
| 19 | + <!-- ROW 1: UNPACK — packed byte → two unsigned nibbles → signed --> |
| 20 | + <!-- ============================================================ --> |
| 21 | + <text x="18" y="20" fill="#666" font-size="10">STEP 1: UNPACK</text> |
| 22 | + |
| 23 | + <!-- Packed byte --> |
| 24 | + <text x="80" y="48" fill="#ccc" font-size="11" text-anchor="middle">w_q[n, i]</text> |
| 25 | + <rect x="20" y="56" width="120" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/> |
| 26 | + <line x1="80" y1="56" x2="80" y2="88" stroke="#4a9edd" stroke-width="1" stroke-dasharray="3,2"/> |
| 27 | + <text x="50" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">hi 7:4</text> |
| 28 | + <text x="110" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">lo 3:0</text> |
| 29 | + |
| 30 | + <!-- Arrow right --> |
| 31 | + <text x="160" y="77" fill="#aaa" font-size="14" text-anchor="middle">→</text> |
| 32 | + |
| 33 | + <!-- Unsigned nibbles --> |
| 34 | + <rect x="180" y="56" width="50" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/> |
| 35 | + <text x="205" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">9</text> |
| 36 | + <rect x="236" y="56" width="50" height="32" fill="#1a4a1a" rx="4" stroke="#7ec87e" stroke-width="1.5"/> |
| 37 | + <text x="261" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">10</text> |
| 38 | + |
| 39 | + <!-- "- 8" arrow --> |
| 40 | + <text x="310" y="77" fill="#ccc" font-size="11" text-anchor="middle">− 8</text> |
| 41 | + <text x="345" y="77" fill="#aaa" font-size="14" text-anchor="middle">→</text> |
| 42 | + |
| 43 | + <!-- Signed int4 --> |
| 44 | + <rect x="365" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/> |
| 45 | + <text x="390" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+1</text> |
| 46 | + <rect x="421" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/> |
| 47 | + <text x="446" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+2</text> |
| 48 | + |
| 49 | + <text x="540" y="77" fill="#888" font-size="10" text-anchor="middle">signed int4 [−8, 7]</text> |
| 50 | + |
| 51 | + <!-- ============================================================ --> |
| 52 | + <!-- ROW 2: GROUP-WISE SCALING — show K=8, group_size=4 --> |
| 53 | + <!-- ============================================================ --> |
| 54 | + <text x="18" y="112" fill="#666" font-size="10">STEP 2: DEQUANTIZE (example: one row n, K=8, group_size=4)</text> |
| 55 | + |
| 56 | + <!-- K-axis label --> |
| 57 | + <text x="350" y="136" fill="#888" font-size="10" text-anchor="middle">k →</text> |
| 58 | + |
| 59 | + <!-- Group 0 bracket + cells --> |
| 60 | + <text x="145" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 0: scale[n, 0]</text> |
| 61 | + <rect x="58" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 62 | + <text x="81" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+1</text> |
| 63 | + <rect x="108" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 64 | + <text x="131" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+2</text> |
| 65 | + <rect x="158" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 66 | + <text x="181" y="161" fill="#e0a040" font-size="10" text-anchor="middle">−1</text> |
| 67 | + <rect x="208" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 68 | + <text x="231" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+3</text> |
| 69 | + <!-- Group 0 bracket --> |
| 70 | + <rect x="56" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/> |
| 71 | + |
| 72 | + <!-- Group 1 bracket + cells --> |
| 73 | + <text x="385" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 1: scale[n, 1]</text> |
| 74 | + <rect x="298" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 75 | + <text x="321" y="161" fill="#e0a040" font-size="10" text-anchor="middle">0</text> |
| 76 | + <rect x="348" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 77 | + <text x="371" y="161" fill="#e0a040" font-size="10" text-anchor="middle">−3</text> |
| 78 | + <rect x="398" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 79 | + <text x="421" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+7</text> |
| 80 | + <rect x="448" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/> |
| 81 | + <text x="471" y="161" fill="#e0a040" font-size="10" text-anchor="middle">−2</text> |
| 82 | + <!-- Group 1 bracket --> |
| 83 | + <rect x="296" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/> |
| 84 | + |
| 85 | + <!-- "int4" label on left --> |
| 86 | + <text x="30" y="161" fill="#e0a040" font-size="9">int4</text> |
| 87 | + |
| 88 | + <!-- Multiply arrows down --> |
| 89 | + <text x="156" y="190" fill="#ccc" font-size="12" text-anchor="middle">× scale[n, 0]</text> |
| 90 | + <text x="396" y="190" fill="#ccc" font-size="12" text-anchor="middle">× scale[n, 1]</text> |
| 91 | + <line x1="156" y1="172" x2="156" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/> |
| 92 | + <line x1="396" y1="172" x2="396" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/> |
| 93 | + |
| 94 | + <!-- Dequantized row --> |
| 95 | + <text x="30" y="217" fill="#40c080" font-size="9">fp16</text> |
| 96 | + <rect x="56" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/> |
| 97 | + <text x="156" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 0..3] float16</text> |
| 98 | + <rect x="296" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/> |
| 99 | + <text x="396" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 4..7] float16</text> |
| 100 | + |
| 101 | + <!-- Formula --> |
| 102 | + <text x="275" y="252" fill="#ccc" font-size="10" text-anchor="middle">W[n, k] = (nibble − 8) × scales[n, k // group_size]</text> |
| 103 | + |
| 104 | + <!-- ============================================================ --> |
| 105 | + <!-- ROW 3: MATMUL --> |
| 106 | + <!-- ============================================================ --> |
| 107 | + <text x="18" y="280" fill="#666" font-size="10">STEP 3: MATMUL</text> |
| 108 | + |
| 109 | + <!-- x box --> |
| 110 | + <rect x="60" y="296" width="80" height="60" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/> |
| 111 | + <text x="100" y="322" fill="#4a9edd" font-size="10" text-anchor="middle">x [M×K]</text> |
| 112 | + <text x="100" y="340" fill="#4a9edd" font-size="9" text-anchor="middle">float16</text> |
| 113 | + |
| 114 | + <!-- multiply sign --> |
| 115 | + <text x="162" y="330" fill="#ccc" font-size="16" text-anchor="middle">×</text> |
| 116 | + |
| 117 | + <!-- W^T box --> |
| 118 | + <rect x="185" y="296" width="100" height="60" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/> |
| 119 | + <text x="235" y="322" fill="#40c080" font-size="10" text-anchor="middle">Wᵀ [K×N]</text> |
| 120 | + <text x="235" y="340" fill="#40c080" font-size="9" text-anchor="middle">float16</text> |
| 121 | + |
| 122 | + <!-- equals sign --> |
| 123 | + <text x="310" y="330" fill="#ccc" font-size="16" text-anchor="middle">=</text> |
| 124 | + |
| 125 | + <!-- y output box --> |
| 126 | + <rect x="335" y="296" width="90" height="60" fill="#3a1a1a" rx="4" stroke="#e05050" stroke-width="1.5"/> |
| 127 | + <text x="380" y="322" fill="#e05050" font-size="10" text-anchor="middle">y [M×N]</text> |
| 128 | + <text x="380" y="340" fill="#e05050" font-size="9" text-anchor="middle">float16</text> |
| 129 | + |
| 130 | + <!-- Arrow from dequant to W^T --> |
| 131 | + <line x1="235" y1="240" x2="235" y2="294" stroke="#40c080" stroke-width="1.5" stroke-dasharray="4,2" marker-end="url(#arr)"/> |
| 132 | + <text x="260" y="270" fill="#40c080" font-size="9">dequantized</text> |
| 133 | +</svg> |
| 134 | + |
| 135 | +<p> |
| 136 | + <strong>Packing format:</strong> Each byte of <code>w_q</code> stores two INT4 weights. The |
| 137 | + high nibble (bits 7–4) holds weight <code>w[n, 2i]</code> and the low nibble (bits |
| 138 | + 3–0) holds <code>w[n, 2i+1]</code>. INT4 values are stored unsigned in the range |
| 139 | + [0, 15] with an offset of 8, so the signed weight is <code>nibble − 8</code>, |
| 140 | + giving values in [−8, 7]. |
| 141 | +</p> |
| 142 | + |
| 143 | +<p> |
| 144 | + <strong>Dequantization:</strong> Weights are dequantized group-wise. Each contiguous block of |
| 145 | + <code>group_size</code> weights along the <code>K</code> dimension shares one float16 scale: |
| 146 | +</p> |
| 147 | +<pre> |
| 148 | +W[n, k] = (w_q_nibble[n, k] - 8) * scales[n, k // group_size] |
| 149 | +</pre> |
| 150 | + |
| 151 | +<h2>Implementation Requirements</h2> |
| 152 | +<ul> |
| 153 | + <li>Use only native features (external libraries are not permitted)</li> |
| 154 | + <li>The <code>solve</code> function signature must remain unchanged</li> |
| 155 | + <li>The final result must be stored in <code>y</code></li> |
| 156 | +</ul> |
| 157 | + |
| 158 | +<h2>Example</h2> |
| 159 | +<p> |
| 160 | + Input (<code>M</code> = 2, <code>N</code> = 4, <code>K</code> = 4, <code>group_size</code> = 2): |
| 161 | +</p> |
| 162 | +<p> |
| 163 | + Activations \(x\) (float16, \(2 \times 4\)): |
| 164 | + \[ |
| 165 | + \begin{bmatrix} |
| 166 | + 1.0 & 0.0 & 1.0 & 0.0 \\ |
| 167 | + 0.0 & 1.0 & 0.0 & 1.0 |
| 168 | + \end{bmatrix} |
| 169 | + \] |
| 170 | + Packed weights \(w\_q\) (uint8, \(4 \times 2\)) with signed INT4 values in brackets: |
| 171 | + \[ |
| 172 | + \begin{bmatrix} |
| 173 | + \texttt{0x99} & \texttt{0x99} \\ |
| 174 | + \texttt{0xAA} & \texttt{0xAA} \\ |
| 175 | + \texttt{0x77} & \texttt{0x77} \\ |
| 176 | + \texttt{0x88} & \texttt{0x88} |
| 177 | + \end{bmatrix} |
| 178 | + \;\Rightarrow\; |
| 179 | + W_{\text{int4}} = |
| 180 | + \begin{bmatrix} |
| 181 | + 1 & 1 & 1 & 1 \\ |
| 182 | + 2 & 2 & 2 & 2 \\ |
| 183 | + -1 & -1 & -1 & -1 \\ |
| 184 | + 0 & 0 & 0 & 0 |
| 185 | + \end{bmatrix} |
| 186 | + \] |
| 187 | + Scales (float16, \(4 \times 2\), all entries 0.5): |
| 188 | + \[ |
| 189 | + \begin{bmatrix} |
| 190 | + 0.5 & 0.5 \\ |
| 191 | + 0.5 & 0.5 \\ |
| 192 | + 0.5 & 0.5 \\ |
| 193 | + 0.5 & 0.5 |
| 194 | + \end{bmatrix} |
| 195 | + \;\Rightarrow\; |
| 196 | + W_{\text{dequant}} = |
| 197 | + \begin{bmatrix} |
| 198 | + 0.5 & 0.5 & 0.5 & 0.5 \\ |
| 199 | + 1.0 & 1.0 & 1.0 & 1.0 \\ |
| 200 | + -0.5 & -0.5 & -0.5 & -0.5 \\ |
| 201 | + 0.0 & 0.0 & 0.0 & 0.0 |
| 202 | + \end{bmatrix} |
| 203 | + \] |
| 204 | + Output \(y = x \times W^T\) (float16, \(2 \times 4\)): |
| 205 | + \[ |
| 206 | + \begin{bmatrix} |
| 207 | + 1.0 & 2.0 & -1.0 & 0.0 \\ |
| 208 | + 1.0 & 2.0 & -1.0 & 0.0 |
| 209 | + \end{bmatrix} |
| 210 | + \] |
| 211 | +</p> |
| 212 | + |
| 213 | +<h2>Constraints</h2> |
| 214 | +<ul> |
| 215 | + <li>1 ≤ <code>M</code>, <code>N</code> ≤ 8,192</li> |
| 216 | + <li>1 ≤ <code>K</code> ≤ 8,192</li> |
| 217 | + <li><code>K</code> is divisible by <code>2</code> and by <code>group_size</code></li> |
| 218 | + <li><code>group_size</code> ∈ {2, 4, 8, 16, 32, 64, 128}</li> |
| 219 | + <li>All tensors are stored in row-major order</li> |
| 220 | + <li>Input dtype: <code>x</code> and <code>scales</code> are float16; <code>w_q</code> is uint8</li> |
| 221 | + <li>Output dtype: <code>y</code> is float16</li> |
| 222 | + <li>Performance is measured with <code>M</code> = 4,096, <code>N</code> = 4,096, <code>K</code> = 4,096, <code>group_size</code> = 128</li> |
| 223 | +</ul> |
0 commit comments