import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase

# OCP FP4 E2M1 lookup table: 4-bit unsigned index -> float value.
# Bit layout: [sign | exp1 exp0 | mantissa]. Sixteen representable values.
FP4_E2M1_TABLE = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
    -0.0,
    -0.5,
    -1.0,
    -1.5,
    -2.0,
    -3.0,
    -4.0,
    -6.0,
]
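

# A minimal illustrative sketch (hypothetical helper, not part of the
# challenge API): decoding one packed byte under the high-nibble-first
# convention used below. The byte 0x2A splits into nibbles 0x2 and 0xA,
# which the table maps to 1.0 and -1.0.
def _decode_fp4_byte_example(byte: int) -> tuple:
    high = (byte >> 4) & 0xF  # first value, bits 7:4
    low = byte & 0xF  # second value, bits 3:0
    return FP4_E2M1_TABLE[high], FP4_E2M1_TABLE[low]


assert _decode_fp4_byte_example(0x2A) == (1.0, -1.0)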


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="FP4 MatMul",
            atol=5e-02,
            rtol=5e-02,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        w_q: torch.Tensor,
        scales: torch.Tensor,
        y: torch.Tensor,
        M: int,
        N: int,
        K: int,
        group_size: int,
    ):
        assert x.shape == (M, K)
        assert w_q.shape == (N, K // 2)
        assert scales.shape == (N, K // group_size)
        assert y.shape == (M, N)
        assert x.dtype == torch.float16
        assert w_q.dtype == torch.uint8
        assert scales.dtype == torch.float16
        assert y.dtype == torch.float16
        assert x.device.type == "cuda"
        assert w_q.device.type == "cuda"
        assert scales.device.type == "cuda"
        assert y.device.type == "cuda"

        # Decode packed FP4 E2M1 nibbles via lookup table.
        # w_q[n, i] holds two FP4 values: w[n, 2*i] in the high nibble (bits 7:4)
        # and w[n, 2*i+1] in the low nibble (bits 3:0).
        table = torch.tensor(FP4_E2M1_TABLE, device=x.device, dtype=torch.float32)
        high = ((w_q >> 4) & 0xF).to(torch.long)  # [N, K//2]
        low = (w_q & 0xF).to(torch.long)  # [N, K//2]
        w_high = table[high]  # [N, K//2]
        w_low = table[low]  # [N, K//2]
        w_fp4 = torch.stack([w_high, w_low], dim=-1).reshape(N, K)  # [N, K]
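        # For intuition, an equivalent interleave (a sketch, assuming even K):
        #   w_fp4 = torch.empty(N, K, dtype=torch.float32, device=x.device)
        #   w_fp4[:, 0::2] = w_high
        #   w_fp4[:, 1::2] = w_low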

        # Apply group-wise FP16 scales: each contiguous block of `group_size`
        # weights along K shares one scale.
        n_groups = K // group_size
        w_groups = w_fp4.reshape(N, n_groups, group_size)
        scales_f = scales.float().unsqueeze(-1)  # [N, n_groups, 1]
        w_dequant = (w_groups * scales_f).reshape(N, K)
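        # e.g. K=64, group_size=32 gives n_groups=2: scales[:, 0] covers
        # w[:, :32] and scales[:, 1] covers w[:, 32:].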

        y.copy_((x.float() @ w_dequant.T).half())

    def get_solve_signature(self) -> Dict[str, tuple]:
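        # float16 buffers cross the C ABI as pointers to 16-bit unsigned
        # integers; the kernel reinterprets those bits as half precision.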
        return {
            "x": (ctypes.POINTER(ctypes.c_uint16), "in"),
            "w_q": (ctypes.POINTER(ctypes.c_uint8), "in"),
            "scales": (ctypes.POINTER(ctypes.c_uint16), "in"),
            "y": (ctypes.POINTER(ctypes.c_uint16), "out"),
            "M": (ctypes.c_int, "in"),
            "N": (ctypes.c_int, "in"),
            "K": (ctypes.c_int, "in"),
            "group_size": (ctypes.c_int, "in"),
        }

    def _make_test_case(self, M: int, N: int, K: int, group_size: int, zero_x: bool = False):
        device = "cuda"
        if zero_x:
            x = torch.zeros(M, K, device=device, dtype=torch.float16)
        else:
            x = torch.randn(M, K, device=device, dtype=torch.float16) * 0.5
        w_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device)
        scales = torch.rand(N, K // group_size, device=device, dtype=torch.float16) * 0.1 + 0.01
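        # Scales land in [0.01, 0.11), keeping dequantized magnitudes small;
        # presumably this holds fp16 rounding inside the 5e-2 tolerances.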
        y = torch.empty(M, N, device=device, dtype=torch.float16)
        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        device = "cuda"
        M, N, K, group_size = 2, 4, 4, 2

        x = torch.tensor(
            [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]],
            device=device,
            dtype=torch.float16,
        )
        # Packed FP4 E2M1 weights (high nibble first).
        # Row 0: FP4 [1.0,1.0,1.0,1.0] -> nibbles [0x2,0x2,0x2,0x2] -> bytes [0x22,0x22] = [34,34]
        # Row 1: FP4 [2.0,2.0,2.0,2.0] -> nibbles [0x4,0x4,0x4,0x4] -> bytes [0x44,0x44] = [68,68]
        # Row 2: FP4 [-1,-1,-1,-1] -> nibbles [0xA,0xA,0xA,0xA] -> bytes [0xAA,0xAA] = [170,170]
        # Row 3: FP4 [0.0,0.0,0.0,0.0] -> nibbles [0x0,0x0,0x0,0x0] -> bytes [0x00,0x00] = [0,0]
        w_q = torch.tensor(
            [[34, 34], [68, 68], [170, 170], [0, 0]],
            dtype=torch.uint8,
            device=device,
        )
        scales = torch.full((N, K // group_size), 0.5, device=device, dtype=torch.float16)
        y = torch.empty(M, N, device=device, dtype=torch.float16)
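        # With every scale at 0.5 the dequantized rows are [0.5]*4, [1.0]*4,
        # [-0.5]*4 and [0.0]*4, so the reference output is
        # y = [[1.0, 2.0, -1.0, 0.0], [1.0, 2.0, -1.0, 0.0]].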

        return {
            "x": x,
            "w_q": w_q,
            "scales": scales,
            "y": y,
            "M": M,
            "N": N,
            "K": K,
            "group_size": group_size,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge cases with tiny shapes.
        tests.append(self._make_test_case(1, 2, 4, 2, zero_x=True))
        tests.append(self._make_test_case(2, 4, 4, 2))
        tests.append(self._make_test_case(3, 5, 8, 4))

        # Power-of-2 shapes.
        tests.append(self._make_test_case(16, 16, 32, 16))
        tests.append(self._make_test_case(32, 64, 64, 32))
        tests.append(self._make_test_case(128, 128, 256, 32))

        # Non-power-of-2 shapes.
        tests.append(self._make_test_case(30, 50, 64, 32))
        tests.append(self._make_test_case(100, 200, 128, 32))
        tests.append(self._make_test_case(255, 100, 128, 32))

        # Realistic LLM inference shape.
        tests.append(self._make_test_case(512, 1024, 1024, 32))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # Matches the FP4 matmul shape reported in AutoKernel community results.
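        # One call at this shape performs 2*M*N*K ~ 1.03e11 floating-point
        # ops over 12 MiB of packed weights (N * K / 2 bytes).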
        return self._make_test_case(2048, 8192, 3072, 32)