Commit 48b263f

Add challenge 81: INT4 Weight-Only Quantized MatMul (Medium) (#216)
* Add challenge 81: INT4 Weight-Only Quantized MatMul (Medium)

  Adds a W4A16 quantized matrix multiplication challenge modelling the core dequantization + GEMM kernel used in all modern LLM inference frameworks (AWQ, GPTQ, llama.cpp, vLLM). Solvers must unpack packed INT4 weights from uint8 bytes, apply group-wise float16 scales, and compute the mixed-precision matrix product against float16 activations.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Fix parameter description comments to include w_q in all starter files

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Improve INT4 matmul SVG visualization for clarity

  Redesign the diagram to clearly show the 3-step dequantization-to-matmul pipeline: unpack nibbles, group-wise scaling with a concrete example (K=8, group_size=4), and final matmul. Fix text overflowing bounding boxes and remove the diagonal arrow that intersected blocks.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: James Song <haoxijamessong@gmail.com>
1 parent 6a506e7 commit 48b263f

8 files changed

Lines changed: 469 additions & 0 deletions

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
<p>
Implement a weight-only INT4 quantized matrix multiplication (W4A16), a core kernel used in
modern LLM inference. Given a float16 activation matrix <code>x</code> of shape
<code>M &times; K</code> and a weight matrix stored in packed INT4 format, compute the output
matrix <code>y = x &times; W<sup>T</sup></code> of shape <code>M &times; N</code>, where
<code>W</code> is the dequantized float16 weight matrix of shape <code>N &times; K</code>.
</p>
<svg width="700" height="400" viewBox="0 0 700 400" xmlns="http://www.w3.org/2000/svg"
     style="display:block; margin:20px auto; font-family:monospace;">
  <rect width="700" height="400" fill="#222" rx="10"/>
  <defs>
    <marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
      <path d="M0,0 L0,6 L8,3 z" fill="#aaa"/>
    </marker>
  </defs>

  <!-- ============================================================ -->
  <!-- ROW 1: UNPACK — packed byte → two unsigned nibbles → signed -->
  <!-- ============================================================ -->
  <text x="18" y="20" fill="#666" font-size="10">STEP 1: UNPACK</text>

  <!-- Packed byte -->
  <text x="80" y="48" fill="#ccc" font-size="11" text-anchor="middle">w_q[n, i]</text>
  <rect x="20" y="56" width="120" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
  <line x1="80" y1="56" x2="80" y2="88" stroke="#4a9edd" stroke-width="1" stroke-dasharray="3,2"/>
  <text x="50" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">hi 7:4</text>
  <text x="110" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">lo 3:0</text>

  <!-- Arrow right -->
  <text x="160" y="77" fill="#aaa" font-size="14" text-anchor="middle">&#x2192;</text>

  <!-- Unsigned nibbles -->
  <rect x="180" y="56" width="50" height="32" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
  <text x="205" y="77" fill="#4a9edd" font-size="10" text-anchor="middle">9</text>
  <rect x="236" y="56" width="50" height="32" fill="#1a4a1a" rx="4" stroke="#7ec87e" stroke-width="1.5"/>
  <text x="261" y="77" fill="#7ec87e" font-size="10" text-anchor="middle">10</text>

  <!-- "- 8" arrow -->
  <text x="310" y="77" fill="#ccc" font-size="11" text-anchor="middle">&#x2212; 8</text>
  <text x="345" y="77" fill="#aaa" font-size="14" text-anchor="middle">&#x2192;</text>

  <!-- Signed int4 -->
  <rect x="365" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/>
  <text x="390" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+1</text>
  <rect x="421" y="56" width="50" height="32" fill="#3a2a1a" rx="4" stroke="#e0a040" stroke-width="1.5"/>
  <text x="446" y="77" fill="#e0a040" font-size="10" text-anchor="middle">+2</text>

  <text x="540" y="77" fill="#888" font-size="10" text-anchor="middle">signed int4 [&#x2212;8, 7]</text>

  <!-- ============================================================ -->
  <!-- ROW 2: GROUP-WISE SCALING — show K=8, group_size=4 -->
  <!-- ============================================================ -->
  <text x="18" y="112" fill="#666" font-size="10">STEP 2: DEQUANTIZE (example: one row n, K=8, group_size=4)</text>

  <!-- K-axis label -->
  <text x="350" y="136" fill="#888" font-size="10" text-anchor="middle">k &#x2192;</text>

  <!-- Group 0 bracket + cells -->
  <text x="145" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 0: scale[n, 0]</text>
  <rect x="58" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="81" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+1</text>
  <rect x="108" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="131" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+2</text>
  <rect x="158" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="181" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;1</text>
  <rect x="208" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="231" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+3</text>
  <!-- Group 0 bracket -->
  <rect x="56" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/>

  <!-- Group 1 bracket + cells -->
  <text x="385" y="136" fill="#c060e0" font-size="9" text-anchor="middle">group 1: scale[n, 1]</text>
  <rect x="298" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="321" y="161" fill="#e0a040" font-size="10" text-anchor="middle">0</text>
  <rect x="348" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="371" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;3</text>
  <rect x="398" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="421" y="161" fill="#e0a040" font-size="10" text-anchor="middle">+7</text>
  <rect x="448" y="142" width="46" height="28" fill="#3a2a1a" rx="3" stroke="#e0a040" stroke-width="1"/>
  <text x="471" y="161" fill="#e0a040" font-size="10" text-anchor="middle">&#x2212;2</text>
  <!-- Group 1 bracket -->
  <rect x="296" y="140" width="200" height="32" rx="4" fill="none" stroke="#c060e0" stroke-width="1.5" stroke-dasharray="4,2"/>

  <!-- "int4" label on left -->
  <text x="30" y="161" fill="#e0a040" font-size="9">int4</text>

  <!-- Multiply arrows down -->
  <text x="156" y="190" fill="#ccc" font-size="12" text-anchor="middle">&#xd7; scale[n, 0]</text>
  <text x="396" y="190" fill="#ccc" font-size="12" text-anchor="middle">&#xd7; scale[n, 1]</text>
  <line x1="156" y1="172" x2="156" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/>
  <line x1="396" y1="172" x2="396" y2="198" stroke="#aaa" stroke-width="1" stroke-dasharray="3,2"/>

  <!-- Dequantized row -->
  <text x="30" y="217" fill="#40c080" font-size="9">fp16</text>
  <rect x="56" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
  <text x="156" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 0..3] float16</text>
  <rect x="296" y="202" width="200" height="28" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
  <text x="396" y="221" fill="#40c080" font-size="10" text-anchor="middle">W[n, 4..7] float16</text>

  <!-- Formula -->
  <text x="275" y="252" fill="#ccc" font-size="10" text-anchor="middle">W[n, k] = (nibble &#x2212; 8) &#xd7; scales[n, k // group_size]</text>

  <!-- ============================================================ -->
  <!-- ROW 3: MATMUL -->
  <!-- ============================================================ -->
  <text x="18" y="280" fill="#666" font-size="10">STEP 3: MATMUL</text>

  <!-- x box -->
  <rect x="60" y="296" width="80" height="60" fill="#1a3a5c" rx="4" stroke="#4a9edd" stroke-width="1.5"/>
  <text x="100" y="322" fill="#4a9edd" font-size="10" text-anchor="middle">x [M&#xd7;K]</text>
  <text x="100" y="340" fill="#4a9edd" font-size="9" text-anchor="middle">float16</text>

  <!-- multiply sign -->
  <text x="162" y="330" fill="#ccc" font-size="16" text-anchor="middle">&#xd7;</text>

  <!-- W^T box -->
  <rect x="185" y="296" width="100" height="60" fill="#1a3a2a" rx="4" stroke="#40c080" stroke-width="1.5"/>
  <text x="235" y="322" fill="#40c080" font-size="10" text-anchor="middle">W&#x1d40; [K&#xd7;N]</text>
  <text x="235" y="340" fill="#40c080" font-size="9" text-anchor="middle">float16</text>

  <!-- equals sign -->
  <text x="310" y="330" fill="#ccc" font-size="16" text-anchor="middle">=</text>

  <!-- y output box -->
  <rect x="335" y="296" width="90" height="60" fill="#3a1a1a" rx="4" stroke="#e05050" stroke-width="1.5"/>
  <text x="380" y="322" fill="#e05050" font-size="10" text-anchor="middle">y [M&#xd7;N]</text>
  <text x="380" y="340" fill="#e05050" font-size="9" text-anchor="middle">float16</text>

  <!-- Arrow from dequant to W^T -->
  <line x1="235" y1="240" x2="235" y2="294" stroke="#40c080" stroke-width="1.5" stroke-dasharray="4,2" marker-end="url(#arr)"/>
  <text x="260" y="270" fill="#40c080" font-size="9">dequantized</text>
</svg>
<p>
<strong>Packing format:</strong> Each byte of <code>w_q</code> stores two INT4 weights. The
high nibble (bits 7&ndash;4) holds weight <code>w[n, 2i]</code> and the low nibble (bits
3&ndash;0) holds <code>w[n, 2i+1]</code>. INT4 values are stored unsigned in the range
[0,&nbsp;15] with an offset of 8, so the signed weight is <code>nibble&nbsp;&minus;&nbsp;8</code>,
giving values in [&minus;8,&nbsp;7].
</p>
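Concretely, the unpacking above is two masks, a shift, and an offset subtraction. A minimal plain-Python sketch (the helper name `unpack_byte` is illustrative, not part of the challenge API):

```python
def unpack_byte(b: int) -> tuple[int, int]:
    """Split one packed uint8 into two signed INT4 weights.

    The high nibble (bits 7:4) comes first, matching the layout where
    w[n, 2i] is stored high and w[n, 2i+1] low.
    """
    hi = (b >> 4) & 0xF  # unsigned nibble in [0, 15]
    lo = b & 0xF
    return hi - 8, lo - 8  # subtract the offset 8 -> signed range [-8, 7]

# 0x9A: high nibble 9 -> +1, low nibble 10 (0xA) -> +2, as in the diagram
print(unpack_byte(0x9A))  # (1, 2)
```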
<p>
<strong>Dequantization:</strong> Weights are dequantized group-wise. Each contiguous block of
<code>group_size</code> weights along the <code>K</code> dimension shares one float16 scale:
</p>
<pre>
W[n, k] = (w_q_nibble[n, k] - 8) * scales[n, k // group_size]
</pre>
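This formula can be exercised on the diagram's Step 2 row (K = 8, group_size = 4, nibbles encoding [+1, +2, &minus;1, +3, 0, &minus;3, +7, &minus;2]); the scale values 0.5 and 2.0 and the helper name `dequant_row` are illustrative assumptions, not given by the challenge:

```python
def dequant_row(nibbles: list[int], scales_row: list[float], group_size: int) -> list[float]:
    """Dequantize one weight row: W[n, k] = (nibble - 8) * scales[n, k // group_size]."""
    return [(q - 8) * scales_row[k // group_size] for k, q in enumerate(nibbles)]

# K=8, group_size=4: group 0 uses scale 0.5, group 1 uses scale 2.0 (assumed values)
row = dequant_row([9, 10, 7, 11, 8, 5, 15, 6], [0.5, 2.0], 4)
print(row)  # [0.5, 1.0, -0.5, 1.5, 0.0, -6.0, 14.0, -4.0]
```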
<h2>Implementation Requirements</h2>
<ul>
  <li>Use only native features (external libraries are not permitted)</li>
  <li>The <code>solve</code> function signature must remain unchanged</li>
  <li>The final result must be stored in <code>y</code></li>
</ul>
<h2>Example</h2>
<p>
Input (<code>M</code> = 2, <code>N</code> = 4, <code>K</code> = 4, <code>group_size</code> = 2):
</p>
<p>
Activations \(x\) (float16, \(2 \times 4\)):
\[
\begin{bmatrix}
1.0 & 0.0 & 1.0 & 0.0 \\
0.0 & 1.0 & 0.0 & 1.0
\end{bmatrix}
\]
Packed weights \(w\_q\) (uint8, \(4 \times 2\)), with the corresponding signed INT4 values shown on the right:
\[
\begin{bmatrix}
\texttt{0x99} & \texttt{0x99} \\
\texttt{0xAA} & \texttt{0xAA} \\
\texttt{0x77} & \texttt{0x77} \\
\texttt{0x88} & \texttt{0x88}
\end{bmatrix}
\;\Rightarrow\;
W_{\text{int4}} =
\begin{bmatrix}
1 & 1 & 1 & 1 \\
2 & 2 & 2 & 2 \\
-1 & -1 & -1 & -1 \\
0 & 0 & 0 & 0
\end{bmatrix}
\]
Scales (float16, \(4 \times 2\), all entries 0.5):
\[
\begin{bmatrix}
0.5 & 0.5 \\
0.5 & 0.5 \\
0.5 & 0.5 \\
0.5 & 0.5
\end{bmatrix}
\;\Rightarrow\;
W_{\text{dequant}} =
\begin{bmatrix}
0.5 & 0.5 & 0.5 & 0.5 \\
1.0 & 1.0 & 1.0 & 1.0 \\
-0.5 & -0.5 & -0.5 & -0.5 \\
0.0 & 0.0 & 0.0 & 0.0
\end{bmatrix}
\]
Output \(y = x \times W^T\) (float16, \(2 \times 4\)):
\[
\begin{bmatrix}
1.0 & 2.0 & -1.0 & 0.0 \\
1.0 & 2.0 & -1.0 & 0.0
\end{bmatrix}
\]
</p>
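The numbers in this example can be reproduced end-to-end with a short pure-Python script (no GPU needed; `solve_ref` is an illustrative reference, not the required <code>solve</code> entry point):

```python
def solve_ref(x, w_q, scales, group_size):
    """Reference W4A16 matmul on plain nested lists (row-major)."""
    # Dequantize W: unpack each byte into (high, low) nibbles, offset by 8,
    # then apply the group-wise scale for position k.
    W = []
    for n, packed_row in enumerate(w_q):
        w_row = []
        for byte in packed_row:
            for nib in ((byte >> 4) & 0xF, byte & 0xF):
                k = len(w_row)
                w_row.append((nib - 8) * scales[n][k // group_size])
        W.append(w_row)
    # y = x @ W^T
    return [[sum(xr[k] * wr[k] for k in range(len(xr))) for wr in W] for xr in x]

x = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
w_q = [[0x99, 0x99], [0xAA, 0xAA], [0x77, 0x77], [0x88, 0x88]]
scales = [[0.5, 0.5]] * 4
print(solve_ref(x, w_q, scales, group_size=2))
# [[1.0, 2.0, -1.0, 0.0], [1.0, 2.0, -1.0, 0.0]]
```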
<h2>Constraints</h2>
<ul>
  <li>1 &le; <code>M</code>, <code>N</code> &le; 8,192</li>
  <li>1 &le; <code>K</code> &le; 8,192</li>
  <li><code>K</code> is divisible by <code>2</code> and by <code>group_size</code></li>
  <li><code>group_size</code> &isin; {2, 4, 8, 16, 32, 64, 128}</li>
  <li>All tensors are stored in row-major order</li>
  <li>Input dtype: <code>x</code> and <code>scales</code> are float16; <code>w_q</code> is uint8</li>
  <li>Output dtype: <code>y</code> is float16</li>
  <li>Performance is measured with <code>M</code> = 4,096, <code>N</code> = 4,096, <code>K</code> = 4,096, <code>group_size</code> = 128</li>
</ul>
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="INT4 Weight-Only Quantized MatMul",
            atol=1e-02,
            rtol=1e-02,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        w_q: torch.Tensor,
        scales: torch.Tensor,
        y: torch.Tensor,
        M: int,
        N: int,
        K: int,
        group_size: int,
    ):
        assert x.shape == (M, K)
        assert w_q.shape == (N, K // 2)
        assert scales.shape == (N, K // group_size)
        assert y.shape == (M, N)
        assert x.dtype == torch.float16
        assert w_q.dtype == torch.uint8
        assert scales.dtype == torch.float16
        assert y.dtype == torch.float16
        assert x.device.type == "cuda"
        assert w_q.device.type == "cuda"
        assert scales.device.type == "cuda"
        assert y.device.type == "cuda"

        # Unpack INT4 weights from packed uint8 bytes.
        # w_q[n, i] stores two weights: w[n, 2*i] in the high nibble (bits 7:4)
        # and w[n, 2*i+1] in the low nibble (bits 3:0).
        # INT4 values are stored unsigned (0–15) with an offset of 8,
        # so the signed value is nibble - 8, giving range [-8, 7].
        w_high = ((w_q >> 4) & 0xF).to(torch.int32) - 8  # [N, K//2]
        w_low = (w_q & 0xF).to(torch.int32) - 8  # [N, K//2]

        # Interleave high and low nibbles to reconstruct [N, K]
        w_int = torch.stack([w_high, w_low], dim=-1).reshape(N, K)  # [N, K]

        # Apply group-wise scales: dequantize each group
        n_groups = K // group_size
        w_groups = w_int.reshape(N, n_groups, group_size).float()  # [N, n_groups, group_size]
        scales_f = scales.float().unsqueeze(-1)  # [N, n_groups, 1]
        w_dequant = (w_groups * scales_f).reshape(N, K)  # [N, K]

        # MatMul: x [M, K] @ w_dequant.T [K, N] = y [M, N]
        y.copy_((x.float() @ w_dequant.T).half())

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "x": (ctypes.POINTER(ctypes.c_uint16), "in"),
            "w_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
            "scales": (ctypes.POINTER(ctypes.c_uint16), "in"),
            "y": (ctypes.POINTER(ctypes.c_uint16), "out"),
            "M": (ctypes.c_int, "in"),
            "N": (ctypes.c_int, "in"),
            "K": (ctypes.c_int, "in"),
            "group_size": (ctypes.c_int, "in"),
        }

    def _make_test_case(self, M: int, N: int, K: int, group_size: int, zero_x: bool = False):
        device = "cuda"
        if zero_x:
            x = torch.zeros(M, K, device=device, dtype=torch.float16)
        else:
            x = torch.randn(M, K, device=device, dtype=torch.float16)
        # Random packed INT4 weights: each byte holds two nibbles in [0,15]
        w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device)
        # Small positive scales to keep magnitudes reasonable
        scales = torch.rand(N, K // group_size, device=device, dtype=torch.float16) * 0.1 + 0.01
        y = torch.empty(M, N, device=device, dtype=torch.float16)
        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        device = "cuda"
        M, N, K, group_size = 2, 4, 4, 2

        x = torch.tensor(
            [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]],
            device=device,
            dtype=torch.float16,
        )
        # Packed INT4 weights (high nibble first).
        # Row 0: weights [1,1,1,1] → nibbles stored as [9,9,9,9] → bytes [0x99, 0x99] = [153, 153]
        # Row 1: weights [2,2,2,2] → nibbles [10,10,10,10] → bytes [0xAA, 0xAA] = [170, 170]
        # Row 2: weights [-1,-1,-1,-1] → nibbles [7,7,7,7] → bytes [0x77, 0x77] = [119, 119]
        # Row 3: weights [0,0,0,0] → nibbles [8,8,8,8] → bytes [0x88, 0x88] = [136, 136]
        w_q = torch.tensor(
            [[153, 153], [170, 170], [119, 119], [136, 136]],
            dtype=torch.uint8,
            device=device,
        )
        # One scale per group (group_size=2 → 2 groups per row), all 0.5
        scales = torch.full((N, K // group_size), 0.5, device=device, dtype=torch.float16)
        y = torch.empty(M, N, device=device, dtype=torch.float16)

        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge cases — tiny K, small group_size
        tests.append(self._make_test_case(1, 2, 4, 2, zero_x=True))
        tests.append(self._make_test_case(2, 4, 4, 2))
        tests.append(self._make_test_case(3, 5, 8, 4))

        # Power-of-2 sizes
        tests.append(self._make_test_case(16, 16, 32, 16))
        tests.append(self._make_test_case(32, 64, 64, 32))
        tests.append(self._make_test_case(64, 128, 128, 64))

        # Non-power-of-2 sizes
        tests.append(self._make_test_case(30, 50, 64, 32))
        tests.append(self._make_test_case(100, 200, 128, 64))
        tests.append(self._make_test_case(255, 100, 128, 64))

        # Realistic LLM inference sizes
        tests.append(self._make_test_case(128, 256, 512, 128))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # Typical LLM weight matrix: 4096×4096 with group_size=128
        return self._make_test_case(4096, 4096, 4096, 128)
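Note that `get_solve_signature` hands the float16 buffers (`x`, `scales`, `y`) to `solve` as `ctypes.c_uint16` pointers, i.e. raw IEEE 754 half-precision bit patterns. A small sketch of that bit-level round trip using only the standard library (`struct` supports the half format code `'e'`; helper names are illustrative):

```python
import struct

def half_to_bits(v: float) -> int:
    """Encode a Python float as the uint16 bit pattern of an IEEE 754 half."""
    return struct.unpack("<H", struct.pack("<e", v))[0]

def bits_to_half(b: int) -> float:
    """Decode a uint16 half bit pattern back to a Python float."""
    return struct.unpack("<e", struct.pack("<H", b))[0]

# 0.5 in float16 is sign 0, exponent 14 (biased), mantissa 0 -> 0x3800
print(hex(half_to_bits(0.5)))  # 0x3800
print(bits_to_half(0x3800))    # 0.5
```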
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>

// x, w_q, scales, y are device pointers
extern "C" void solve(const __half* x, const uint8_t* w_q, const __half* scales, __half* y, int M,
                      int N, int K, int group_size) {}
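The starter stub leaves the kernel strategy open. As a hedged sketch of the arithmetic a single CUDA thread assigned to one output element y[m][n] might perform — read K/2 packed bytes, unpack two weights per byte, scale, multiply-accumulate — written here in scalar Python for clarity (names are illustrative, not part of the starter):

```python
def compute_one_output(x_row, wq_row, scales_row, K, group_size):
    """Scalar inner loop for one y[m, n]: dequantize w on the fly and accumulate.

    Mirrors what one thread assigned to (m, n) could do over the K dimension.
    """
    acc = 0.0
    for i in range(K // 2):
        byte = wq_row[i]
        # High nibble is w[n, 2i], low nibble is w[n, 2i+1]
        for j, nib in enumerate(((byte >> 4) & 0xF, byte & 0xF)):
            k = 2 * i + j
            w = (nib - 8) * scales_row[k // group_size]
            acc += x_row[k] * w
    return acc

# One element of the worked example: x row [1,0,1,0] against weight row 0x99 0x99
print(compute_one_output([1.0, 0.0, 1.0, 0.0], [0x99, 0x99], [0.5, 0.5], K=4, group_size=2))
# 1.0
```

A real kernel would additionally tile over M and N, keep the accumulator in float32, and convert to __half only on the final store.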
