Skip to content

Commit 6a506e7

Browse files
claude[bot]github-actions[bot]claude
authored
Add challenge 78: 2D FFT (Medium) (#208)
Introduces the 2D Discrete Fourier Transform challenge, teaching GPU programmers to implement a row-column decomposition FFT. Key concepts: batched 1D FFTs, coalesced vs strided memory access patterns for row/column processing, and shared-memory butterfly kernels. Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent de854e1 commit 6a506e7

8 files changed

Lines changed: 196 additions & 0 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
<p>
2+
Compute the 2D Discrete Fourier Transform (2D DFT) of a complex-valued signal stored on the GPU.
3+
Given a 2D complex input signal of shape <code>M &times; N</code>, compute its 2D DFT spectrum
4+
using the row-column decomposition: apply a 1D DFT along each row, then a 1D DFT along each
5+
column of the result. All values are 32-bit floating point.
6+
</p>
7+
8+
<h2>Implementation Requirements</h2>
9+
<ul>
10+
<li>Use only native features (external libraries are not permitted)</li>
11+
<li>The <code>solve</code> function signature must remain unchanged</li>
12+
<li>The final result must be stored in <code>spectrum</code></li>
13+
<li>
14+
The input and output are stored as 1D arrays of interleaved real and imaginary parts in
15+
row-major order: element <code>x[m, n]</code> has its real part at index
16+
<code>2*(m*N + n)</code> and imaginary part at index <code>2*(m*N + n) + 1</code>
17+
</li>
18+
</ul>
19+
20+
<h2>Example</h2>
21+
<p>
22+
Input: <code>M</code> = 2, <code>N</code> = 2<br>
23+
Signal \(x[m, n]\) (real part):
24+
\[
25+
\begin{bmatrix}
26+
1.0 & 0.0 \\
27+
0.0 & 0.0
28+
\end{bmatrix}
29+
\]
30+
Signal \(x[m, n]\) (imaginary part):
31+
\[
32+
\begin{bmatrix}
33+
0.0 & 0.0 \\
34+
0.0 & 0.0
35+
\end{bmatrix}
36+
\]
37+
Output:<br>
38+
Spectrum \(X[k, l]\) (real part):
39+
\[
40+
\begin{bmatrix}
41+
1.0 & 1.0 \\
42+
1.0 & 1.0
43+
\end{bmatrix}
44+
\]
45+
Spectrum \(X[k, l]\) (imaginary part):
46+
\[
47+
\begin{bmatrix}
48+
0.0 & 0.0 \\
49+
0.0 & 0.0
50+
\end{bmatrix}
51+
\]
52+
</p>
53+
54+
<h2>Constraints</h2>
55+
<ul>
56+
<li>1 &le; <code>M</code>, <code>N</code> &le; 4096</li>
57+
<li>Signal values are 32-bit floating point (real and imaginary parts)</li>
58+
<li>Performance is measured with <code>M</code> = 2,048, <code>N</code> = 2,048</li>
59+
</ul>
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import ctypes
2+
from typing import Any, Dict, List
3+
4+
import torch
5+
from core.challenge_base import ChallengeBase
6+
7+
8+
class Challenge(ChallengeBase):
    """Challenge 78: 2D FFT.

    Contestants compute the 2D DFT of a complex signal stored as a flat
    float32 tensor of length 2*M*N with interleaved (real, imag) parts in
    row-major order: x[m, n].real at 2*(m*N + n), x[m, n].imag at
    2*(m*N + n) + 1.
    """

    def __init__(self):
        super().__init__(
            name="2D FFT",
            atol=1e-02,
            rtol=1e-02,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(self, signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
        """Compute the reference spectrum with ``torch.fft.fft2``.

        Args:
            signal: flat float32 tensor of shape (2*M*N,), interleaved (re, im).
            spectrum: flat float32 output tensor of the same shape; written in place.
            M, N: signal dimensions (rows, columns).
        """
        assert signal.shape == (M * N * 2,)
        assert spectrum.shape == (M * N * 2,)
        assert signal.dtype == torch.float32
        assert spectrum.dtype == torch.float32
        assert signal.device == spectrum.device

        # view_as_complex gives a zero-copy complex view of the interleaved
        # pairs (requires a contiguous trailing dim of size 2; .contiguous()
        # is a no-op for the contiguous tensors this challenge generates).
        sig_c = torch.view_as_complex(signal.contiguous().view(M, N, 2))
        spec_c = torch.fft.fft2(sig_c)
        # view_as_real re-interleaves (real, imag) without an explicit stack/copy.
        spectrum.copy_(torch.view_as_real(spec_c.contiguous()).reshape(-1))

    def get_solve_signature(self) -> Dict[str, tuple]:
        """C ABI for the contestant's ``solve``: float* in, float* out, int M, int N."""
        return {
            "signal": (ctypes.POINTER(ctypes.c_float), "in"),
            "spectrum": (ctypes.POINTER(ctypes.c_float), "out"),
            "M": (ctypes.c_int, "in"),
            "N": (ctypes.c_int, "in"),
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """2x2 unit impulse from the problem statement (spectrum is all ones)."""
        dtype = torch.float32
        M, N = 2, 2
        signal = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device="cuda", dtype=dtype)
        spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
        return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Functional suite: tiny edge cases, power-of-2, non-power-of-2, and realistic sizes."""
        dtype = torch.float32
        cases = []

        def make_case(M, N, low=-1.0, high=1.0):
            # Uniform random complex signal in [low, high) for both parts.
            signal = torch.empty(M * N * 2, device="cuda", dtype=dtype).uniform_(low, high)
            spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
            return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

        def make_zero_case(M, N):
            # All-zero signal: spectrum must be exactly zero.
            signal = torch.zeros(M * N * 2, device="cuda", dtype=dtype)
            spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
            return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

        def make_impulse_case(M, N):
            # Unit impulse at (0, 0): spectrum is all ones (real part).
            signal = torch.zeros(M * N * 2, device="cuda", dtype=dtype)
            signal[0] = 1.0
            spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
            return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}

        # Edge cases: small sizes
        cases.append(make_impulse_case(1, 1))
        cases.append(make_zero_case(2, 2))
        cases.append(make_case(1, 4))

        # Power-of-2 sizes
        cases.append(make_case(16, 16))
        cases.append(make_case(32, 64))

        # Non-power-of-2 sizes
        cases.append(make_case(3, 5))
        cases.append(make_case(30, 30))

        # Mixed positive/negative values
        cases.append(make_case(100, 200, low=-5.0, high=5.0))

        # Realistic sizes
        cases.append(make_case(256, 256))
        cases.append(make_case(512, 512))

        return cases

    def generate_performance_test(self) -> Dict[str, Any]:
        """Timed case: 2048x2048 standard-normal complex signal (matches the stated benchmark)."""
        dtype = torch.float32
        M, N = 2048, 2048
        signal = torch.empty(M * N * 2, device="cuda", dtype=dtype).normal_(0.0, 1.0)
        spectrum = torch.empty(M * N * 2, device="cuda", dtype=dtype)
        return {"signal": signal, "spectrum": spectrum, "M": M, "N": N}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#include <cuda_runtime.h>

// Starter stub for the 2D FFT challenge: implement the row-column
// decomposition DFT here (the provided body intentionally does nothing).
// signal, spectrum are device pointers: flat float arrays of length 2*M*N
// holding interleaved (real, imag) pairs in row-major order — element
// (m, n) has its real part at 2*(m*N + n) and imaginary part at
// 2*(m*N + n) + 1.
extern "C" void solve(const float* signal, float* spectrum, int M, int N) {}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import cutlass
import cutlass.cute as cute


# Starter stub for the 2D FFT challenge (CUTLASS/CuTe variant); the body is
# intentionally empty for the contestant to fill in.
# signal, spectrum are tensors on the GPU: flat float32 tensors of length
# 2*M*N with interleaved (real, imag) pairs in row-major order — real part
# at 2*(m*N + n), imaginary part at 2*(m*N + n) + 1.
@cute.jit
def solve(signal: cute.Tensor, spectrum: cute.Tensor, M: cute.Int32, N: cute.Int32):
    pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import jax
import jax.numpy as jnp


# Starter stub for the 2D FFT challenge (JAX variant); the body is
# intentionally empty for the contestant to fill in.
# signal is a tensor on GPU: a flat float32 array of length 2*M*N with
# interleaved (real, imag) pairs in row-major order — real part at
# 2*(m*N + n), imaginary part at 2*(m*N + n) + 1. Unlike the other
# variants, this one returns the spectrum instead of writing an out-param.
@jax.jit
def solve(signal: jax.Array, M: int, N: int) -> jax.Array:
    # return output tensor directly
    pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# Starter stub for the 2D FFT challenge (Mojo variant); the body is
# intentionally empty for the contestant to fill in.
# signal, spectrum are device pointers: flat Float32 buffers of length
# 2*M*N with interleaved (real, imag) pairs in row-major order — real part
# at 2*(m*N + n), imaginary part at 2*(m*N + n) + 1.
@export
def solve(signal: UnsafePointer[Float32], spectrum: UnsafePointer[Float32], M: Int32, N: Int32):
    pass
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import torch


# Starter stub for the 2D FFT challenge (PyTorch variant); the body is
# intentionally empty for the contestant to fill in.
# signal, spectrum are tensors on the GPU: flat float32 tensors of length
# 2*M*N with interleaved (real, imag) pairs in row-major order — real part
# at 2*(m*N + n), imaginary part at 2*(m*N + n) + 1. Write the result
# into spectrum in place.
def solve(signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
    pass
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch
import triton
import triton.language as tl


# Starter stub for the 2D FFT challenge (Triton variant); the body is
# intentionally empty for the contestant to fill in.
# signal, spectrum are tensors on the GPU: flat float32 tensors of length
# 2*M*N with interleaved (real, imag) pairs in row-major order — real part
# at 2*(m*N + n), imaginary part at 2*(m*N + n) + 1. Write the result
# into spectrum in place.
def solve(signal: torch.Tensor, spectrum: torch.Tensor, M: int, N: int):
    pass

0 commit comments

Comments
 (0)