
Commit e099a71

Add FP4 matmul hard challenge
Weight-only FP4 E2M1 quantized matmul (W4A16) with group-wise FP16 scales, the kernel powering low-precision LLM inference on Hopper and Blackwell. Two FP4 values are packed per uint8 byte; each contiguous block of group_size weights along K shares one scale.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ca891d3 commit e099a71

8 files changed

Lines changed: 377 additions & 0 deletions


Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
<p>
Implement an FP4 weight-only quantized matrix multiplication, the kernel at the heart of
modern low-precision LLM inference on Hopper and Blackwell GPUs. Given a float16 activation
matrix <code>x</code> of shape <code>M &times; K</code> and a weight matrix stored in packed
FP4 E2M1 format, compute <code>y = x &times; W<sup>T</sup></code> of shape
<code>M &times; N</code>, where <code>W</code> is the dequantized float16 weight matrix of
shape <code>N &times; K</code>.
</p>

<p>
<strong>FP4 E2M1 format:</strong> Each weight is encoded in 4 bits as
[sign | exponent (2 bits) | mantissa (1 bit)], representing one of sixteen values:
<code>{&plusmn;0, &plusmn;0.5, &plusmn;1, &plusmn;1.5, &plusmn;2, &plusmn;3, &plusmn;4, &plusmn;6}</code>.
The nibble-to-value mapping is:
</p>
<pre>
0x0 =  0.0      0x8 = -0.0
0x1 =  0.5      0x9 = -0.5
0x2 =  1.0      0xA = -1.0
0x3 =  1.5      0xB = -1.5
0x4 =  2.0      0xC = -2.0
0x5 =  3.0      0xD = -3.0
0x6 =  4.0      0xE = -4.0
0x7 =  6.0      0xF = -6.0
</pre>
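<p>
(Illustrative note: the sixteen values above follow from the E2M1 bit fields with an exponent
bias of 1. Below is a small Python sketch of a hypothetical <code>fp4_e2m1_decode</code> helper
that reconstructs the mapping; the table itself remains the normative definition.)
</p>
<pre>
def fp4_e2m1_decode(nibble: int) -> float:
    sign = -1.0 if nibble & 0x8 else 1.0
    exp = (nibble >> 1) & 0x3          # 2 exponent bits, bias 1
    man = nibble & 0x1                 # 1 mantissa bit
    if exp == 0:                       # subnormal: 0.0 or 0.5
        mag = 0.5 * man
    else:                              # normal: (1 + man/2) * 2**(exp - 1)
        mag = (1.0 + 0.5 * man) * 2.0 ** (exp - 1)
    return sign * mag

assert [fp4_e2m1_decode(i) for i in range(8)] == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
</pre>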

<p>
<strong>Packing:</strong> Each byte of <code>w_q</code> stores two FP4 weights. The high
nibble (bits 7&ndash;4) holds <code>w[n, 2i]</code> and the low nibble (bits 3&ndash;0) holds
<code>w[n, 2i+1]</code>.
</p>

<p>
<strong>Dequantization:</strong> Weights are dequantized group-wise. Each contiguous block of
<code>group_size</code> weights along the <code>K</code> dimension shares one float16 scale:
</p>
<pre>
W[n, k] = fp4_decode(w_q_nibble[n, k]) * scales[n, k // group_size]
</pre>
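<p>
For concreteness, a minimal sketch of decoding a single packed byte and applying its group
scale. The byte <code>0x4A</code> and the scale <code>0.25</code> are made-up illustrative
values, and <code>FP4</code> is the sixteen-entry table above:
</p>
<pre>
FP4 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
       -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

byte, scale = 0x4A, 0.25                  # hypothetical weight byte and its group scale
w_even = FP4[(byte >> 4) & 0xF] * scale   # high nibble -> w[n, 2i]:    2.0 * 0.25 =  0.5
w_odd  = FP4[byte & 0xF] * scale          # low nibble  -> w[n, 2i+1]: -1.0 * 0.25 = -0.25
</pre>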

<h2>Implementation Requirements</h2>
<ul>
  <li>Use only native features (external libraries are not permitted)</li>
  <li>The <code>solve</code> function signature must remain unchanged</li>
  <li>The final result must be stored in <code>y</code></li>
</ul>

<h2>Example</h2>
<p>
Input (<code>M</code> = 2, <code>N</code> = 4, <code>K</code> = 4, <code>group_size</code> = 2):
</p>
<p>
Activations \(x\) (float16, \(2 \times 4\)):
\[
\begin{bmatrix}
1.0 & 0.0 & 1.0 & 0.0 \\
0.0 & 1.0 & 0.0 & 1.0
\end{bmatrix}
\]
Packed weights \(w\_q\) (uint8, \(4 \times 2\)) decoded via the FP4 E2M1 table:
\[
\begin{bmatrix}
\texttt{0x22} & \texttt{0x22} \\
\texttt{0x44} & \texttt{0x44} \\
\texttt{0xAA} & \texttt{0xAA} \\
\texttt{0x00} & \texttt{0x00}
\end{bmatrix}
\;\Rightarrow\;
W_{\text{fp4}} =
\begin{bmatrix}
1.0 & 1.0 & 1.0 & 1.0 \\
2.0 & 2.0 & 2.0 & 2.0 \\
-1.0 & -1.0 & -1.0 & -1.0 \\
0.0 & 0.0 & 0.0 & 0.0
\end{bmatrix}
\]
Scales (float16, \(4 \times 2\), all entries 0.5):
\[
\begin{bmatrix}
0.5 & 0.5 \\
0.5 & 0.5 \\
0.5 & 0.5 \\
0.5 & 0.5
\end{bmatrix}
\;\Rightarrow\;
W_{\text{dequant}} =
\begin{bmatrix}
0.5 & 0.5 & 0.5 & 0.5 \\
1.0 & 1.0 & 1.0 & 1.0 \\
-0.5 & -0.5 & -0.5 & -0.5 \\
0.0 & 0.0 & 0.0 & 0.0
\end{bmatrix}
\]
Output \(y = x \times W^T\) (float16, \(2 \times 4\)):
\[
\begin{bmatrix}
1.0 & 2.0 & -1.0 & 0.0 \\
1.0 & 2.0 & -1.0 & 0.0
\end{bmatrix}
\]
</p>
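<p>
This example can be reproduced with a short self-contained PyTorch sketch (for illustration
only; float32 is used here for brevity rather than the float16 of the actual task):
</p>
<pre>
import torch

FP4 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

x = torch.tensor([[1., 0., 1., 0.], [0., 1., 0., 1.]])
w_q = torch.tensor([[0x22, 0x22], [0x44, 0x44], [0xAA, 0xAA], [0x00, 0x00]], dtype=torch.uint8)
scales = torch.full((4, 2), 0.5)

hi = FP4[((w_q >> 4) & 0xF).long()]                  # w[n, 2i]   from the high nibbles
lo = FP4[(w_q & 0xF).long()]                         # w[n, 2i+1] from the low nibbles
W = torch.stack([hi, lo], dim=-1).reshape(4, 4)      # decoded FP4 values, shape N x K
W = (W.reshape(4, 2, 2) * scales.unsqueeze(-1)).reshape(4, 4)   # apply group_size = 2 scales
y = x @ W.T                                          # [[1, 2, -1, 0], [1, 2, -1, 0]]
</pre>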

<h2>Constraints</h2>
<ul>
  <li>1 &le; <code>M</code>, <code>N</code> &le; 8,192</li>
  <li>1 &le; <code>K</code> &le; 8,192</li>
  <li><code>K</code> is divisible by <code>2</code> and by <code>group_size</code></li>
  <li><code>group_size</code> &isin; {2, 4, 8, 16, 32}</li>
  <li>All tensors are stored in row-major order</li>
  <li>Input dtype: <code>x</code> and <code>scales</code> are float16; <code>w_q</code> is uint8</li>
  <li>Output dtype: <code>y</code> is float16</li>
  <li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 8,192, <code>K</code> = 3,072, <code>group_size</code> = 32</li>
</ul>
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase

# OCP FP4 E2M1 lookup table: 4-bit unsigned index -> float value.
# Bit layout: [sign | exp1 exp0 | mantissa]. Sixteen representable values.
FP4_E2M1_TABLE = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
    -0.0,
    -0.5,
    -1.0,
    -1.5,
    -2.0,
    -3.0,
    -4.0,
    -6.0,
]


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="FP4 MatMul",
            atol=5e-02,
            rtol=5e-02,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        w_q: torch.Tensor,
        scales: torch.Tensor,
        y: torch.Tensor,
        M: int,
        N: int,
        K: int,
        group_size: int,
    ):
        assert x.shape == (M, K)
        assert w_q.shape == (N, K // 2)
        assert scales.shape == (N, K // group_size)
        assert y.shape == (M, N)
        assert x.dtype == torch.float16
        assert w_q.dtype == torch.uint8
        assert scales.dtype == torch.float16
        assert y.dtype == torch.float16
        assert x.device.type == "cuda"
        assert w_q.device.type == "cuda"
        assert scales.device.type == "cuda"
        assert y.device.type == "cuda"

        # Decode packed FP4 E2M1 nibbles via lookup table.
        # w_q[n, i] holds two FP4 values: w[n, 2*i] in the high nibble (bits 7:4)
        # and w[n, 2*i+1] in the low nibble (bits 3:0).
        table = torch.tensor(FP4_E2M1_TABLE, device=x.device, dtype=torch.float32)
        high = ((w_q >> 4) & 0xF).to(torch.long)  # [N, K//2]
        low = (w_q & 0xF).to(torch.long)  # [N, K//2]
        w_high = table[high]  # [N, K//2]
        w_low = table[low]  # [N, K//2]
        w_fp4 = torch.stack([w_high, w_low], dim=-1).reshape(N, K)  # [N, K]

        # Apply group-wise FP16 scales: each contiguous block of `group_size`
        # weights along K shares one scale.
        n_groups = K // group_size
        w_groups = w_fp4.reshape(N, n_groups, group_size)
        scales_f = scales.float().unsqueeze(-1)  # [N, n_groups, 1]
        w_dequant = (w_groups * scales_f).reshape(N, K)

        y.copy_((x.float() @ w_dequant.T).half())

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "x": (ctypes.POINTER(ctypes.c_uint16), "in"),
            "w_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
            "scales": (ctypes.POINTER(ctypes.c_uint16), "in"),
            "y": (ctypes.POINTER(ctypes.c_uint16), "out"),
            "M": (ctypes.c_int, "in"),
            "N": (ctypes.c_int, "in"),
            "K": (ctypes.c_int, "in"),
            "group_size": (ctypes.c_int, "in"),
        }

    def _make_test_case(self, M: int, N: int, K: int, group_size: int, zero_x: bool = False):
        device = "cuda"
        if zero_x:
            x = torch.zeros(M, K, device=device, dtype=torch.float16)
        else:
            x = torch.randn(M, K, device=device, dtype=torch.float16) * 0.5
        w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device)
        scales = torch.rand(N, K // group_size, device=device, dtype=torch.float16) * 0.1 + 0.01
        y = torch.empty(M, N, device=device, dtype=torch.float16)
        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        device = "cuda"
        M, N, K, group_size = 2, 4, 4, 2

        x = torch.tensor(
            [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]],
            device=device,
            dtype=torch.float16,
        )
        # Packed FP4 E2M1 weights (high nibble first).
        # Row 0: FP4 [1.0,1.0,1.0,1.0] -> nibbles [0x2,0x2,0x2,0x2] -> bytes [0x22,0x22] = [34,34]
        # Row 1: FP4 [2.0,2.0,2.0,2.0] -> nibbles [0x4,0x4,0x4,0x4] -> bytes [0x44,0x44] = [68,68]
        # Row 2: FP4 [-1,-1,-1,-1] -> nibbles [0xA,0xA,0xA,0xA] -> bytes [0xAA,0xAA] = [170,170]
        # Row 3: FP4 [0.0,0.0,0.0,0.0] -> nibbles [0x0,0x0,0x0,0x0] -> bytes [0x00,0x00] = [0,0]
        w_q = torch.tensor(
            [[34, 34], [68, 68], [170, 170], [0, 0]],
            dtype=torch.uint8,
            device=device,
        )
        scales = torch.full((N, K // group_size), 0.5, device=device, dtype=torch.float16)
        y = torch.empty(M, N, device=device, dtype=torch.float16)

        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge cases with tiny shapes.
        tests.append(self._make_test_case(1, 2, 4, 2, zero_x=True))
        tests.append(self._make_test_case(2, 4, 4, 2))
        tests.append(self._make_test_case(3, 5, 8, 4))

        # Power-of-2 shapes.
        tests.append(self._make_test_case(16, 16, 32, 16))
        tests.append(self._make_test_case(32, 64, 64, 32))
        tests.append(self._make_test_case(128, 128, 256, 32))

        # Non-power-of-2 shapes.
        tests.append(self._make_test_case(30, 50, 64, 32))
        tests.append(self._make_test_case(100, 200, 128, 32))
        tests.append(self._make_test_case(255, 100, 128, 32))

        # Realistic LLM inference shape.
        tests.append(self._make_test_case(512, 1024, 1024, 32))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # Matches the FP4 matmul shapes reported in AutoKernel community results.
        return self._make_test_case(2048, 8192, 3072, 32)
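
For orientation, a back-of-envelope Python sketch of the work implied by the performance-test shape above; the figures are derived directly from the tensor shapes, not measured:

M, N, K, group_size = 2048, 8192, 3072, 32

flops        = 2 * M * N * K                 # ~1.03e11 FLOPs per call
bytes_x      = M * K * 2                     # float16 activations:   ~12.6 MB
bytes_w_q    = N * K // 2                    # packed FP4 weights:    ~12.6 MB
bytes_scales = N * (K // group_size) * 2     # float16 group scales:  ~1.6 MB
bytes_y      = M * N * 2                     # float16 output:        ~33.6 MB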
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>

// x, w_q, scales, y are device pointers
extern "C" void solve(const __half* x, const uint8_t* w_q, const __half* scales, __half* y, int M,
                      int N, int K, int group_size) {}
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import cutlass
import cutlass.cute as cute


# x, w_q, scales, y are tensors on the GPU
@cute.jit
def solve(
    x: cute.Tensor,
    w_q: cute.Tensor,
    scales: cute.Tensor,
    y: cute.Tensor,
    M: cute.Int32,
    N: cute.Int32,
    K: cute.Int32,
    group_size: cute.Int32,
):
    pass
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import jax
import jax.numpy as jnp


# x, w_q, scales are tensors on GPU
@jax.jit
def solve(
    x: jax.Array,
    w_q: jax.Array,
    scales: jax.Array,
    M: int,
    N: int,
    K: int,
    group_size: int,
) -> jax.Array:
    # return output tensor directly
    pass
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
from std.gpu.host import DeviceContext
from std.memory import UnsafePointer


# x, w_q, scales, y are device pointers
@export
def solve(
    x: UnsafePointer[Float16, MutExternalOrigin],
    w_q: UnsafePointer[UInt8, MutExternalOrigin],
    scales: UnsafePointer[Float16, MutExternalOrigin],
    y: UnsafePointer[Float16, MutExternalOrigin],
    M: Int32,
    N: Int32,
    K: Int32,
    group_size: Int32,
) raises:
    pass
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch


# x, w_q, scales, y are tensors on the GPU
def solve(
    x: torch.Tensor,
    w_q: torch.Tensor,
    scales: torch.Tensor,
    y: torch.Tensor,
    M: int,
    N: int,
    K: int,
    group_size: int,
):
    pass
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import torch
import triton
import triton.language as tl


# x, w_q, scales, y are tensors on the GPU
def solve(
    x: torch.Tensor,
    w_q: torch.Tensor,
    scales: torch.Tensor,
    y: torch.Tensor,
    M: int,
    N: int,
    K: int,
    group_size: int,
):
    pass
