
Add FP4 matmul hard challenge #241

Open

kunal-mansukhani wants to merge 2 commits into main from
add-fp4-matmul-challenge

Conversation

@kunal-mansukhani
Contributor

Summary

  • New hard challenge 86_fp4_matmul: weight-only FP4 E2M1 quantized matmul (W4A16) with group-wise FP16 scales, the kernel powering low-precision LLM inference on Hopper/Blackwell.
  • Two FP4 values are packed per uint8 byte (high/low nibble); each contiguous block of group_size weights along K shares one FP16 scale.
  • Performance shape M=2048, N=8192, K=3072, group_size=32 mirrors a row from the AutoKernel paper's FP4 matmul results table.
  • All 6 framework starters (CUDA, Triton, PyTorch, JAX, CuTe, Mojo) are included, each with hard-tier parameter description comments.

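To make the packing and scale layout concrete, here is a minimal numpy sketch of the W4A16 dequantization the bullets describe. The nibble order (low nibble holds the even K index), the `(N, K // 2)` packed shape, and the function names are assumptions for illustration; the actual challenge's `reference_impl` may differ.

```python
import numpy as np

# FP4 E2M1 magnitude lookup: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa
# bit give the eight non-negative values 0, 0.5, 1, 1.5, 2, 3, 4, 6.
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def dequant_w4a16(w_packed, scales, group_size=32):
    """Dequantize FP4 weights packed two per uint8 with group-wise FP16 scales.

    w_packed: (N, K // 2) uint8; low nibble = even K index (assumed order).
    scales:   (N, K // group_size) float16, one per contiguous K block.
    Returns an (N, K) float32 weight matrix.
    """
    n, half_k = w_packed.shape
    codes = np.empty((n, half_k * 2), dtype=np.uint8)
    codes[:, 0::2] = w_packed & 0x0F   # low nibble
    codes[:, 1::2] = w_packed >> 4     # high nibble
    sign = np.where(codes & 0x8, -1.0, 1.0).astype(np.float32)
    w = sign * E2M1[codes & 0x7]
    # One FP16 scale is shared by each contiguous group_size block along K.
    return w * np.repeat(scales.astype(np.float32), group_size, axis=1)

def matmul_w4a16(x, w_packed, scales, group_size=32):
    """W4A16 reference: FP16 activations against dequantized FP4 weights."""
    return x.astype(np.float32) @ dequant_w4a16(w_packed, scales, group_size).T
```

For example, the packed bytes `[0x42, 0x0A]` decode (low nibble first) to codes 2, 4, 10, 0, i.e. the values 1, 2, -1, 0 before scaling.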
Test plan

  • pre-commit run --all-files on new files: black, isort, flake8, clang-format, mojo format all pass
  • challenge.py imports cleanly; all 6 required methods present
  • reference_impl verified numerically on the example: y = [[1,2,-1,0],[1,2,-1,0]] matches the HTML example
  • Functional tests: 10 cases covering edge (1-4), powers-of-2, non-powers-of-2, realistic LLM shape, zero inputs
  • Run via scripts/run_challenge.py --language cuda --action run on Tesla T4

🤖 Generated with Claude Code

Weight-only FP4 E2M1 quantized matmul (W4A16) with group-wise FP16
scales, the kernel powering low-precision LLM inference on Hopper and
Blackwell. Two FP4 values are packed per uint8 byte; each contiguous
block of group_size weights along K shares one scale.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Restructures the challenge so a submission directly verifies AutoKernel's
FP4 matmul claim (Table 5 of the paper): both operands are packed FP4
E2M1 with E4M3 per-block scales and a per-tensor FP32 alpha, matching
the NVFP4 layout used by CUTLASS and qutlass. The previous revision was
W4A16 weight-only quantization, which could not reach the TF/s regime the
paper reports because x was still FP16.

Key changes:
- Both x and w are packed FP4 uint8 (nibbles); block size = 16.
- Scales are raw E4M3 bytes (torch.float8_e4m3fn bit patterns).
- Reference is a pure FP32 dequant + matmul oracle.
- Performance shape (M=2048, N=18432, K=3072) taken verbatim from the
  Triton vs CUTLASS row in Table 5 so TF/s is directly comparable.
- Tolerances loosened to atol=0.1, rtol=0.05 to admit FP16 accumulation
  used by tensor-core paths.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
