Skip to content

Commit 33b83a3

Browse files
Add challenge 74: Layer Normalization (Medium)
Layer normalization is a core building block of transformer architectures (BERT, GPT, LLaMA). Unlike batch normalization, it normalizes across the feature dimension of each sample independently, which requires efficient two-pass reductions (mean, then variance) with shared memory — a non-trivial GPU programming challenge.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 0578315 commit 33b83a3

8 files changed

Lines changed: 370 additions & 0 deletions

File tree

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
<p>
Implement the forward pass of layer normalization for a 2D input tensor. Given an input tensor of shape [N, C], where N is the batch size and C is the number of features, normalize each sample independently across its C features, then apply learnable scale (<code>weight</code>) and shift (<code>bias</code>) parameters. Layer normalization is a core building block of transformer architectures.
</p>

<p>
For each sample \(i\), layer normalization computes:
\[
\begin{align}
\mu_i &= \frac{1}{C} \sum_{j=0}^{C-1} x_{i,j} \\
\sigma_i^2 &= \frac{1}{C} \sum_{j=0}^{C-1} (x_{i,j} - \mu_i)^2 \\
y_{i,j} &= \text{weight}_j \cdot \frac{x_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} + \text{bias}_j
\end{align}
\]
</p>
<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in the <code>output</code> tensor</li>
</ul>
<h2>Example</h2>
<p>
Input:<br>
\(\text{input}\) (N=2, C=4):
\[
\begin{bmatrix}
1.0 & 2.0 & 3.0 & 4.0 \\
-1.0 & 0.0 & 0.0 & 1.0
\end{bmatrix}
\]
\(\text{weight}\):
\[
\begin{bmatrix}
1.0 & 1.0 & 1.0 & 1.0
\end{bmatrix}
\]
\(\text{bias}\):
\[
\begin{bmatrix}
0.0 & 0.0 & 0.0 & 0.0
\end{bmatrix}
\]
\(\varepsilon\) = 1e-5<br><br>
Output:<br>
\(\text{output}\) (N=2, C=4):
\[
\begin{bmatrix}
-1.3416 & -0.4472 & 0.4472 & 1.3416 \\
-1.4142 & 0.0 & 0.0 & 1.4142
\end{bmatrix}
\]
</p>
<h2>Constraints</h2>
<ul>
<li>1 &le; <code>N</code> &le; 65,536</li>
<li>1 &le; <code>C</code> &le; 4,096</li>
<li><code>eps</code> = 1e-5</li>
<li>Input values are in the range [-100.0, 100.0]</li>
<li>Weight values are in the range [0.1, 10.0]</li>
<li>Bias values are in the range [-10.0, 10.0]</li>
<li>Performance is measured with <code>N</code> = 65,536, <code>C</code> = 512</li>
</ul>
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import ctypes
2+
from typing import Any, Dict, List
3+
4+
import torch
5+
from core.challenge_base import ChallengeBase
6+
7+
8+
class Challenge(ChallengeBase):
    """Challenge definition for the Layer Normalization problem.

    Each row of an [N, C] float32 tensor is normalized to zero mean and unit
    variance over its C features, then scaled and shifted by per-feature
    ``weight`` and ``bias`` parameters:

        y[i, j] = weight[j] * (x[i, j] - mean_i) / sqrt(var_i + eps) + bias[j]
    """

    def __init__(self):
        super().__init__(
            name="Layer Normalization", atol=1e-04, rtol=1e-04, num_gpus=1, access_tier="free"
        )

    def reference_impl(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        output: torch.Tensor,
        N: int,
        C: int,
        eps: float,
    ):
        """Ground-truth layer norm computed with PyTorch ops; writes into ``output``."""
        # Validate shapes, dtypes, and device placement before computing.
        assert input.shape == output.shape == (N, C)
        assert weight.shape == bias.shape == (C,)
        assert input.dtype == weight.dtype == bias.dtype == output.dtype
        assert input.device == weight.device == bias.device == output.device
        assert str(input.device).startswith("cuda")

        # Per-row statistics; unbiased=False matches the 1/C population variance
        # in the challenge definition.
        mean = input.mean(dim=1, keepdim=True)
        var = input.var(dim=1, keepdim=True, unbiased=False)
        normalized = (input - mean) / torch.sqrt(var + eps)
        output.copy_(weight * normalized + bias)

    def get_solve_signature(self) -> Dict[str, tuple]:
        """ctypes signature of the user-provided ``solve`` entry point."""
        float_ptr = ctypes.POINTER(ctypes.c_float)
        return {
            "input": (float_ptr, "in"),
            "weight": (float_ptr, "in"),
            "bias": (float_ptr, "in"),
            "output": (float_ptr, "out"),
            "N": (ctypes.c_int, "in"),
            "C": (ctypes.c_int, "in"),
            "eps": (ctypes.c_float, "in"),
        }

    @staticmethod
    def _case(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, N: int, C: int
    ) -> Dict[str, Any]:
        """Package one test case dict, allocating the output buffer and fixing eps."""
        return {
            "input": input,
            "weight": weight,
            "bias": bias,
            "output": torch.empty((N, C), device="cuda", dtype=torch.float32),
            "N": N,
            "C": C,
            "eps": 1e-5,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """The 2x4 example from the challenge description (identity affine params)."""
        dtype = torch.float32
        example_input = torch.tensor(
            [[1.0, 2.0, 3.0, 4.0], [-1.0, 0.0, 0.0, 1.0]], device="cuda", dtype=dtype
        )
        return self._case(
            example_input,
            torch.ones(4, device="cuda", dtype=dtype),
            torch.zeros(4, device="cuda", dtype=dtype),
            2,
            4,
        )

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Hand-picked edge cases followed by randomized shapes of increasing size."""
        dtype = torch.float32
        tests = []

        # edge: single element per row
        tests.append(
            self._case(
                torch.tensor([[3.0]], device="cuda", dtype=dtype),
                torch.tensor([1.0], device="cuda", dtype=dtype),
                torch.tensor([0.5], device="cuda", dtype=dtype),
                1,
                1,
            )
        )

        # edge: 2x2, all zeros
        tests.append(
            self._case(
                torch.zeros((2, 2), device="cuda", dtype=dtype),
                torch.ones(2, device="cuda", dtype=dtype),
                torch.zeros(2, device="cuda", dtype=dtype),
                2,
                2,
            )
        )

        # edge: 4x4, negative values with non-trivial weight/bias
        tests.append(
            self._case(
                torch.tensor(
                    [
                        [-1.0, -2.0, -3.0, -4.0],
                        [1.0, 2.0, 3.0, 4.0],
                        [0.0, 0.0, 0.0, 0.0],
                        [-2.0, 0.0, 2.0, 4.0],
                    ],
                    device="cuda",
                    dtype=dtype,
                ),
                torch.tensor([1.0, 2.0, 1.0, 0.5], device="cuda", dtype=dtype),
                torch.tensor([0.0, 0.0, 1.0, -1.0], device="cuda", dtype=dtype),
                4,
                4,
            )
        )

        # Randomized cases: (N, C, input range, weight range, bias range).
        # A range of None means identity weight (ones) / bias (zeros).
        # Covers power-of-2, non-power-of-2, and a realistic BERT-sized row.
        specs = [
            (8, 16, (-5.0, 5.0), (0.5, 2.0), (-1.0, 1.0)),
            (32, 64, (-10.0, 10.0), (0.5, 2.0), (-2.0, 2.0)),
            (128, 256, (-10.0, 10.0), (0.5, 2.0), (-2.0, 2.0)),
            (7, 30, (-5.0, 5.0), None, None),
            (15, 100, (-100.0, 100.0), (0.1, 3.0), (-5.0, 5.0)),
            (25, 255, (-10.0, 10.0), (0.5, 2.0), (-1.0, 1.0)),
            (512, 768, (-5.0, 5.0), (0.5, 2.0), (-1.0, 1.0)),
        ]
        # Draw order (input, weight, bias) is fixed so the RNG stream stays stable.
        for n, c, in_rng, w_rng, b_rng in specs:
            inp = torch.empty((n, c), device="cuda", dtype=dtype).uniform_(*in_rng)
            w = (
                torch.empty(c, device="cuda", dtype=dtype).uniform_(*w_rng)
                if w_rng is not None
                else torch.ones(c, device="cuda", dtype=dtype)
            )
            b = (
                torch.empty(c, device="cuda", dtype=dtype).uniform_(*b_rng)
                if b_rng is not None
                else torch.zeros(c, device="cuda", dtype=dtype)
            )
            tests.append(self._case(inp, w, b, n, c))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        """Benchmark case matching the published constraints: N=65,536, C=512."""
        dtype = torch.float32
        N, C = 65536, 512
        return self._case(
            torch.empty((N, C), device="cuda", dtype=dtype).uniform_(-5.0, 10.0),
            torch.empty(C, device="cuda", dtype=dtype).uniform_(0.5, 2.0),
            torch.empty(C, device="cuda", dtype=dtype).uniform_(-1.0, 1.0),
            N,
            C,
        )
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include <cuda_runtime.h>

// Starter stub for the CUDA track: implement layer normalization here.
// input, weight, bias, and output all point to device (GPU) memory.
extern "C" void solve(const float* input, const float* weight, const float* bias, float* output,
                      int N, int C, float eps) {}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import cutlass
import cutlass.cute as cute


# Starter stub for the CuTe DSL track: implement layer normalization here.
# input, weight, bias, output are tensors resident on the GPU.
@cute.jit
def solve(
    input: cute.Tensor,
    weight: cute.Tensor,
    bias: cute.Tensor,
    output: cute.Tensor,
    N: cute.Int32,
    C: cute.Int32,
    eps: cute.Float32,
):
    # Write the normalized result into `output`.
    pass
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import jax
import jax.numpy as jnp


# Starter stub for the JAX track: implement layer normalization here.
# input, weight, bias are arrays resident on the GPU.
@jax.jit
def solve(
    input: jax.Array, weight: jax.Array, bias: jax.Array, N: int, C: int, eps: float
) -> jax.Array:
    # Return the normalized output array directly (no in-place output buffer).
    pass
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv


# Starter stub for the Mojo track: implement layer normalization here.
# input, weight, bias, output are device pointers.
@export
def solve(input: UnsafePointer[Float32], weight: UnsafePointer[Float32],
          bias: UnsafePointer[Float32], output: UnsafePointer[Float32],
          N: Int32, C: Int32, eps: Float32):
    pass
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch


# Starter stub for the PyTorch track: implement layer normalization here.
# input, weight, bias, output are tensors resident on the GPU.
def solve(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    output: torch.Tensor,
    N: int,
    C: int,
    eps: float,
):
    # Write the normalized result into `output` in place.
    pass
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
import triton
import triton.language as tl


# Starter stub for the Triton track: implement layer normalization here.
# input, weight, bias, output are tensors resident on the GPU.
def solve(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    output: torch.Tensor,
    N: int,
    C: int,
    eps: float,
):
    # Launch a Triton kernel that writes the result into `output`.
    pass

0 commit comments

Comments
 (0)