Commit c1b7afe
Add challenge 92: Decaying Causal Attention (Medium)
Implements the core computation of the Retention mechanism (RetNet): causal unnormalized attention with geometric decay weights. Each position n attends to all past positions m <= n with weight gamma^(n-m), requiring solvers to reason about triangular memory access patterns, on-the-fly decay factor computation, and tiled accumulation strategies.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ca891d3 commit c1b7afe

8 files changed

Lines changed: 394 additions & 0 deletions

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
<p>
Implement decaying causal attention. Given query matrix <code>Q</code>, key matrix <code>K</code>,
and value matrix <code>V</code>, each of shape <code>seq_len &times; d_model</code>, and a scalar
decay factor <code>gamma</code> &isin; (0,&nbsp;1], compute the unnormalized causal attention output
where position <code>n</code> attends to all past positions <code>m &le; n</code> with weight
<code>gamma<sup>n&minus;m</sup></code>:
</p>
<p>
\[
\text{output}[n] = \sum_{m=0}^{n} \gamma^{n-m} \cdot \frac{Q[n] \cdot K[m]}{\sqrt{d_{\text{model}}}} \cdot V[m]
\]
</p>
<p>
Unlike standard softmax attention, there is no normalization — the weights decay geometrically from
the current position backward. This is the parallel form of the Retention mechanism (RetNet), used
as a recurrence-friendly alternative to attention in sequence models.
</p>

<svg width="680" height="215" viewBox="0 0 680 215" xmlns="http://www.w3.org/2000/svg"
     style="display:block; margin:20px auto;">
  <rect width="680" height="215" fill="#222" rx="8"/>

  <!-- Section title: decay mask -->
  <text x="148" y="24" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">Causal Decay Mask D[n,m] = &#947;^(n&#8722;m)</text>

  <!-- Column headers m=0..3 -->
  <text x="80" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=0</text>
  <text x="125" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=1</text>
  <text x="170" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=2</text>
  <text x="215" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=3</text>

  <!-- Row labels n=0..3 -->
  <text x="42" y="72" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=0</text>
  <text x="42" y="112" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=1</text>
  <text x="42" y="152" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=2</text>
  <text x="42" y="192" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=3</text>

  <!-- Row 0 -->
  <rect x="58" y="53" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
  <text x="80" y="76" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>
  <rect x="103" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>
  <rect x="148" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>
  <rect x="193" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>

  <!-- Row 1 -->
  <rect x="58" y="91" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/>
  <text x="80" y="114" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">&#947;</text>
  <rect x="103" y="91" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
  <text x="125" y="114" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>
  <rect x="148" y="91" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>
  <rect x="193" y="91" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>

  <!-- Row 2 -->
  <rect x="58" y="129" width="44" height="36" fill="#0e2a54" stroke="#333" stroke-width="1"/>
  <text x="80" y="152" text-anchor="middle" fill="#2a7ec0" font-size="11" font-family="monospace">&#947;&#178;</text>
  <rect x="103" y="129" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/>
  <text x="125" y="152" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">&#947;</text>
  <rect x="148" y="129" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
  <text x="170" y="152" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>
  <rect x="193" y="129" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>

  <!-- Row 3 -->
  <rect x="58" y="167" width="44" height="36" fill="#081e3c" stroke="#333" stroke-width="1"/>
  <text x="80" y="190" text-anchor="middle" fill="#1a6ea0" font-size="11" font-family="monospace">&#947;&#179;</text>
  <rect x="103" y="167" width="44" height="36" fill="#0e2a54" stroke="#333" stroke-width="1"/>
  <text x="125" y="190" text-anchor="middle" fill="#2a7ec0" font-size="11" font-family="monospace">&#947;&#178;</text>
  <rect x="148" y="167" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/>
  <text x="170" y="190" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">&#947;</text>
  <rect x="193" y="167" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
  <text x="215" y="190" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>

  <!-- Divider -->
  <line x1="265" y1="30" x2="265" y2="210" stroke="#444" stroke-width="1" stroke-dasharray="4,3"/>

  <!-- Right side: computation flow -->
  <text x="472" y="24" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">Computation</text>

  <defs>
    <marker id="arr2" markerWidth="7" markerHeight="7" refX="5" refY="3" orient="auto">
      <path d="M0,0 L0,6 L7,3 Z" fill="#888"/>
    </marker>
  </defs>

  <!-- Step boxes -->
  <rect x="280" y="48" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
  <text x="330" y="62" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">Q [S, D]</text>
  <text x="330" y="74" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">query</text>

  <rect x="280" y="90" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
  <text x="330" y="104" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">K [S, D]</text>
  <text x="330" y="116" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">key</text>

  <rect x="280" y="132" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
  <text x="330" y="146" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">V [S, D]</text>
  <text x="330" y="158" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">value</text>

  <!-- Arrow from Q and K to scores -->
  <line x1="380" y1="64" x2="412" y2="90" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
  <line x1="380" y1="106" x2="412" y2="96" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>

  <rect x="414" y="78" width="110" height="34" rx="4" fill="#1a2a3c" stroke="#7ec8a0" stroke-width="1.2"/>
  <text x="469" y="92" text-anchor="middle" fill="#7ec8a0" font-size="10" font-family="monospace">QK&#7488; / &#8730;D</text>
  <text x="469" y="105" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">attn scores [S,S]</text>

  <!-- Arrow: multiply by decay mask -->
  <line x1="469" y1="112" x2="469" y2="128" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
  <text x="505" y="124" fill="#cc88ff" font-size="9" font-family="monospace">&#8857; decay mask</text>

  <rect x="414" y="130" width="110" height="34" rx="4" fill="#2a1a3c" stroke="#cc88ff" stroke-width="1.2"/>
  <text x="469" y="144" text-anchor="middle" fill="#cc88ff" font-size="10" font-family="monospace">weighted [S,S]</text>
  <text x="469" y="157" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">lower triangular</text>

  <!-- Arrow from V and weighted to output -->
  <line x1="380" y1="148" x2="412" y2="148" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
  <line x1="524" y1="147" x2="546" y2="147" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
  <text x="535" y="140" fill="#888" font-size="9" font-family="monospace">@</text>

  <rect x="548" y="131" width="110" height="34" rx="4" fill="#1a3a1c" stroke="#4aff88" stroke-width="1.2"/>
  <text x="603" y="145" text-anchor="middle" fill="#4aff88" font-size="10" font-family="monospace">output [S, D]</text>
  <text x="603" y="158" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">weighted @ V</text>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the <code>solve</code> function; do not change its signature.</li>
<li>Do not use external libraries beyond those provided.</li>
<li>Write the result into <code>output</code>.</li>
</ul>

<h2>Example</h2>
<p>Example 1 — with <code>seq_len</code> = 2, <code>d_model</code> = 4, <code>gamma</code> = 0.5:</p>
<p>
\[
Q = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 1 & 1 & 0 & 0 \end{bmatrix}, \quad
K = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}, \quad
V = \begin{bmatrix} 4 & 8 & 12 & 16 \\ 4 & 8 & 12 & 16 \end{bmatrix}
\]
</p>
<p>
Attention scores \(QK^\top / \sqrt{4}\):
\[
A = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix}
\]
Causal decay mask \(D_{nm} = 0.5^{n-m}\) for \(n \ge m\), else \(0\):
\[
D = \begin{bmatrix} 1 & 0 \\ 0.5 & 1 \end{bmatrix}
\]
Weighted attention \(A \odot D\):
\[
\begin{bmatrix} 0.5 & 0 \\ 0.25 & 0.5 \end{bmatrix}
\]
Output \((A \odot D)\,V\):
\[
\text{output} = \begin{bmatrix} 2 & 4 & 6 & 8 \\ 3 & 6 & 9 & 12 \end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>seq_len</code> &le; 8,192</li>
<li>1 &le; <code>d_model</code> &le; 256</li>
<li>0 &lt; <code>gamma</code> &le; 1</li>
<li>All tensors are <code>float32</code> on GPU.</li>
<li>Performance is measured with <code>seq_len</code> = 4,096, <code>d_model</code> = 64.</li>
</ul>
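The worked example above can be checked end to end with a short CPU-side sketch of the parallel form (NumPy here purely for illustration; the challenge itself is solved with a GPU kernel):

```python
import numpy as np

def decaying_causal_attention(Q, K, V, gamma):
    """Parallel form: (Q K^T / sqrt(d_model) * decay_mask) @ V with a
    lower-triangular decay mask gamma^(n-m)."""
    seq_len, d_model = Q.shape
    scores = (Q @ K.T) / np.sqrt(d_model)                             # [S, S] scaled scores
    dist = np.arange(seq_len)[:, None] - np.arange(seq_len)[None, :]  # dist[n, m] = n - m
    decay = np.where(dist >= 0, gamma ** np.maximum(dist, 0), 0.0)    # causal geometric decay
    return (scores * decay) @ V                                       # [S, D]

# Example 1 from above: seq_len=2, d_model=4, gamma=0.5
Q = np.array([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
K = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
V = np.array([[4.0, 8.0, 12.0, 16.0], [4.0, 8.0, 12.0, 16.0]])
out = decaying_causal_attention(Q, K, V, gamma=0.5)
# matches the stated output: [[2, 4, 6, 8], [3, 6, 9, 12]]
```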
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="Decaying Causal Attention",
            atol=1e-03,
            rtol=1e-03,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        output: torch.Tensor,
        seq_len: int,
        d_model: int,
        gamma: float,
    ):
        assert Q.shape == (seq_len, d_model)
        assert K.shape == (seq_len, d_model)
        assert V.shape == (seq_len, d_model)
        assert output.shape == (seq_len, d_model)
        assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
        assert Q.device.type == "cuda"
        assert K.device.type == "cuda"
        assert V.device.type == "cuda"
        assert output.device.type == "cuda"

        scale = math.sqrt(d_model)
        positions = torch.arange(seq_len, device=Q.device, dtype=Q.dtype)
        # distances[n, m] = n - m; negative means m is in the future relative to n
        distances = positions.unsqueeze(1) - positions.unsqueeze(0)
        # causal: zero out future positions; clamp avoids overflow in gamma**negative
        causal = (distances >= 0).to(Q.dtype)
        decay_mask = torch.pow(gamma, distances.clamp(min=0)) * causal
        attn = torch.matmul(Q, K.T) / scale
        output.copy_(torch.matmul(attn * decay_mask, V))

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
            "K": (ctypes.POINTER(ctypes.c_float), "in"),
            "V": (ctypes.POINTER(ctypes.c_float), "in"),
            "output": (ctypes.POINTER(ctypes.c_float), "out"),
            "seq_len": (ctypes.c_int, "in"),
            "d_model": (ctypes.c_int, "in"),
            "gamma": (ctypes.c_float, "in"),
        }

    def generate_example_test(self) -> Dict[str, Any]:
        dtype = torch.float32
        device = "cuda"
        # Orthogonal K rows → QK^T / sqrt(4) = [[0.5, 0.5], [0.5, 0.5]].
        # With gamma=0.5 decay mask [[1, 0], [0.5, 1]], weighted attn = [[0.5, 0], [0.25, 0.5]].
        # Output row 0 = 0.5 * V[0]; row 1 = 0.25 * V[0] + 0.5 * V[1] = [3, 6, 9, 12].
        Q = torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype)
        K = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype)
        V = torch.tensor(
            [[4.0, 8.0, 12.0, 16.0], [4.0, 8.0, 12.0, 16.0]], device=device, dtype=dtype
        )
        output = torch.zeros(2, 4, device=device, dtype=dtype)
        return {"Q": Q, "K": K, "V": V, "output": output, "seq_len": 2, "d_model": 4, "gamma": 0.5}

    def _make_test_case(
        self,
        seq_len: int,
        d_model: int,
        gamma: float = 0.9,
        zero_qk: bool = False,
        negative: bool = False,
    ) -> Dict[str, Any]:
        dtype = torch.float32
        device = "cuda"
        if zero_qk:
            Q = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
            K = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
            V = torch.randn(seq_len, d_model, device=device, dtype=dtype)
        elif negative:
            Q = torch.randn(seq_len, d_model, device=device, dtype=dtype).neg()
            K = torch.randn(seq_len, d_model, device=device, dtype=dtype).neg()
            V = torch.randn(seq_len, d_model, device=device, dtype=dtype).neg()
        else:
            Q = torch.randn(seq_len, d_model, device=device, dtype=dtype)
            K = torch.randn(seq_len, d_model, device=device, dtype=dtype)
            V = torch.randn(seq_len, d_model, device=device, dtype=dtype)
        output = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "seq_len": seq_len,
            "d_model": d_model,
            "gamma": gamma,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge: single token (only self-attention possible)
        tests.append(self._make_test_case(1, 4, gamma=0.9))

        # Edge: two tokens (matches example structure)
        tests.append(self._make_test_case(2, 4, gamma=0.5))

        # Edge: gamma=1.0 — no decay, equal weight to all past positions
        tests.append(self._make_test_case(4, 8, gamma=1.0))

        # Edge: small gamma — very sharp recency bias
        tests.append(self._make_test_case(4, 8, gamma=0.1))

        # Zero Q and K: all attention scores are zero → output must be all zeros
        tests.append(self._make_test_case(8, 16, gamma=0.9, zero_qk=True))

        # All-negative Q, K, V
        tests.append(self._make_test_case(16, 16, gamma=0.8, negative=True))

        # Power-of-2 sequence length
        tests.append(self._make_test_case(32, 32, gamma=0.9))

        # Power-of-2, larger
        tests.append(self._make_test_case(64, 64, gamma=0.8))

        # Non-power-of-2 sequence length
        tests.append(self._make_test_case(30, 32, gamma=0.95))

        # Non-power-of-2, larger realistic size
        tests.append(self._make_test_case(100, 64, gamma=0.9))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # Typical LLM head: seq_len=4096, head_dim=64
        return self._make_test_case(4096, 64, gamma=0.9)
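Because the decay factors each position's sum into the previous position's sum, the same output also admits a sequential O(S·D²) formulation, the recurrent form of retention: carry a running state S_n = gamma·S_(n-1) + K[n]ᵀV[n] and read out Q[n]·S_n / sqrt(d_model). A minimal NumPy sketch, for intuition rather than performance:

```python
import numpy as np

def retention_recurrent(Q, K, V, gamma):
    """Recurrent form: the state accumulates gamma-decayed outer products
    outer(K[m], V[m]); each step decays history once and adds the new pair."""
    seq_len, d_model = Q.shape
    state = np.zeros((d_model, d_model))
    out = np.empty_like(Q)
    for n in range(seq_len):
        state = gamma * state + np.outer(K[n], V[n])  # decay old history, add new pair
        out[n] = (Q[n] @ state) / np.sqrt(d_model)    # read out against full history
    return out

# Agrees with the parallel (masked-matmul) form on random inputs
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((16, 8)) for _ in range(3))
dist = np.arange(16)[:, None] - np.arange(16)[None, :]
mask = np.where(dist >= 0, 0.9 ** np.maximum(dist, 0), 0.0)
parallel = ((Q @ K.T) / np.sqrt(8) * mask) @ V
recurrent = retention_recurrent(Q, K, V, gamma=0.9)
```

The parallel form above is what the reference implementation computes; the recurrent form is the "recurrence-friendly" reading mentioned in the challenge description.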
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Q, K, V, output are device pointers
extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int seq_len,
                      int d_model, float gamma) {}
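The "tiled accumulation strategies" the commit message mentions typically take a chunkwise form: process the sequence in blocks, handle intra-block pairs with a local decay mask, and carry everything older through a running state that decays by gamma^block per step. Sketched here in NumPy for clarity (block size and names are illustrative; a kernel would follow the same dataflow with on-chip tiles):

```python
import numpy as np

def retention_chunkwise(Q, K, V, gamma, block=4):
    """Chunkwise schedule: intra-block term via a local decay mask,
    inter-block term via a running (d_model x d_model) state."""
    seq_len, d_model = Q.shape
    scale = np.sqrt(d_model)
    out = np.empty_like(Q)
    state = np.zeros((d_model, d_model))  # decayed sum of outer(K[m], V[m]) before this block
    i = np.arange(block)
    local = np.where(i[:, None] >= i[None, :],
                     gamma ** np.maximum(i[:, None] - i[None, :], 0), 0.0)
    for c in range(0, seq_len, block):
        b = min(block, seq_len - c)  # final block may be ragged
        Qc, Kc, Vc = Q[c:c+b], K[c:c+b], V[c:c+b]
        intra = ((Qc @ Kc.T) / scale * local[:b, :b]) @ Vc              # pairs inside the block
        cross = (gamma ** (i[:b] + 1))[:, None] * (Qc @ state) / scale  # pairs before the block
        out[c:c+b] = intra + cross
        # fold this block into the state: old history decays by gamma^b
        state = (gamma ** b) * state + (Kc * (gamma ** (b - 1 - i[:b]))[:, None]).T @ Vc
    return out

# Agrees with the direct parallel form, including a ragged final block
rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
dist = np.arange(10)[:, None] - np.arange(10)[None, :]
mask = np.where(dist >= 0, 0.8 ** np.maximum(dist, 0), 0.0)
direct = ((Q @ K.T) / np.sqrt(8) * mask) @ V
tiled = retention_chunkwise(Q, K, V, gamma=0.8, block=4)
```

This keeps per-block work dense (good for tensor cores or shared-memory tiles) while avoiding the O(S²) mask for cross-block history.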
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
import cutlass
import cutlass.cute as cute


# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
    Q: cute.Tensor,
    K: cute.Tensor,
    V: cute.Tensor,
    output: cute.Tensor,
    seq_len: cute.Int32,
    d_model: cute.Int32,
    gamma: cute.Float32,
):
    pass
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
import jax
import jax.numpy as jnp


# Q, K, V are tensors on GPU
@jax.jit
def solve(
    Q: jax.Array,
    K: jax.Array,
    V: jax.Array,
    seq_len: int,
    d_model: int,
    gamma: float,
) -> jax.Array:
    # return output tensor directly
    pass
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
from std.gpu.host import DeviceContext
from std.memory import UnsafePointer


# Q, K, V, output are device pointers
@export
def solve(
    Q: UnsafePointer[Float32, MutExternalOrigin],
    K: UnsafePointer[Float32, MutExternalOrigin],
    V: UnsafePointer[Float32, MutExternalOrigin],
    output: UnsafePointer[Float32, MutExternalOrigin],
    seq_len: Int32,
    d_model: Int32,
    gamma: Float32,
) raises:
    pass
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
import torch


# Q, K, V, output are tensors on the GPU
def solve(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    output: torch.Tensor,
    seq_len: int,
    d_model: int,
    gamma: float,
):
    pass
