165 changes: 165 additions & 0 deletions challenges/medium/92_decaying_causal_attention/challenge.html
@@ -0,0 +1,165 @@
<p>
Implement decaying causal attention. Given query matrix <code>Q</code>, key matrix <code>K</code>,
and value matrix <code>V</code>, each of shape <code>seq_len &times; d_model</code>, and a scalar
decay factor <code>gamma</code> &isin; (0,&nbsp;1], compute the unnormalized causal attention output
where position <code>n</code> attends to all past positions <code>m &le; n</code> with weight
<code>gamma<sup>n&minus;m</sup></code>:
</p>
<p>
\[
\text{output}[n] = \sum_{m=0}^{n} \gamma^{n-m} \cdot \frac{Q[n] \cdot K[m]}{\sqrt{d_{\text{model}}}} \cdot V[m]
\]
</p>
<p>
Unlike standard softmax attention, there is no normalization — the weights decay geometrically from
the current position backward. This is the parallel form of the Retention mechanism (RetNet), used
as a recurrence-friendly alternative to attention in sequence models.
</p>
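<p>
The "recurrence-friendly" claim can be checked directly: because the decay is a single geometric
factor, maintaining a running <code>d_model &times; d_model</code> state
<code>S_n = gamma * S_{n-1} + outer(K[n], V[n])</code> and emitting
<code>Q[n] @ S_n / sqrt(d_model)</code> reproduces the parallel form above. A NumPy sketch
(illustrative only, with hypothetical helper names; not part of the challenge files):
</p>

```python
import numpy as np

def parallel_form(Q, K, V, gamma):
    # Decay mask D[n, m] = gamma**(n - m) for n >= m, else 0.
    S, d = Q.shape
    n = np.arange(S)
    dist = n[:, None] - n[None, :]
    D = np.where(dist >= 0, gamma ** np.maximum(dist, 0), 0.0)
    return ((Q @ K.T) / np.sqrt(d) * D) @ V

def recurrent_form(Q, K, V, gamma):
    # Constant-size state per step: state = gamma * state + outer(K[n], V[n]).
    S, d = Q.shape
    state = np.zeros((d, d))
    out = np.empty_like(V)
    for n in range(S):
        state = gamma * state + np.outer(K[n], V[n])
        out[n] = (Q[n] @ state) / np.sqrt(d)
    return out
```

<p>
Both forms agree to floating-point tolerance, which is what makes an O(1)-memory streaming
implementation possible at inference time.
</p>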

<svg width="680" height="215" viewBox="0 0 680 215" xmlns="http://www.w3.org/2000/svg"
style="display:block; margin:20px auto;">
<rect width="680" height="215" fill="#222" rx="8"/>

<!-- Section title: decay mask -->
<text x="148" y="24" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">Causal Decay Mask D[n,m] = &#947;^(n&#8722;m)</text>

<!-- Column headers m=0..3 -->
<text x="80" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=0</text>
<text x="125" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=1</text>
<text x="170" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=2</text>
<text x="215" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=3</text>

<!-- Row labels n=0..3 -->
<text x="42" y="72" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=0</text>
<text x="42" y="112" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=1</text>
<text x="42" y="152" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=2</text>
<text x="42" y="192" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=3</text>

<!-- Row 0 -->
<rect x="58" y="53" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
<text x="80" y="76" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>
<rect x="103" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>
<rect x="148" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>
<rect x="193" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>

<!-- Row 1 -->
<rect x="58" y="91" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/>
<text x="80" y="114" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">&#947;</text>
<rect x="103" y="91" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
<text x="125" y="114" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>
<rect x="148" y="91" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>
<rect x="193" y="91" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>

<!-- Row 2 -->
<rect x="58" y="129" width="44" height="36" fill="#0e2a54" stroke="#333" stroke-width="1"/>
<text x="80" y="152" text-anchor="middle" fill="#2a7ec0" font-size="11" font-family="monospace">&#947;&#178;</text>
<rect x="103" y="129" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/>
<text x="125" y="152" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">&#947;</text>
<rect x="148" y="129" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
<text x="170" y="152" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>
<rect x="193" y="129" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/>

<!-- Row 3 -->
<rect x="58" y="167" width="44" height="36" fill="#081e3c" stroke="#333" stroke-width="1"/>
<text x="80" y="190" text-anchor="middle" fill="#1a6ea0" font-size="11" font-family="monospace">&#947;&#179;</text>
<rect x="103" y="167" width="44" height="36" fill="#0e2a54" stroke="#333" stroke-width="1"/>
<text x="125" y="190" text-anchor="middle" fill="#2a7ec0" font-size="11" font-family="monospace">&#947;&#178;</text>
<rect x="148" y="167" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/>
<text x="170" y="190" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">&#947;</text>
<rect x="193" y="167" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/>
<text x="215" y="190" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text>

<!-- Divider -->
<line x1="265" y1="30" x2="265" y2="210" stroke="#444" stroke-width="1" stroke-dasharray="4,3"/>

<!-- Right side: computation flow -->
<text x="472" y="24" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">Computation</text>

<defs>
<marker id="arr2" markerWidth="7" markerHeight="7" refX="5" refY="3" orient="auto">
<path d="M0,0 L0,6 L7,3 Z" fill="#888"/>
</marker>
</defs>

<!-- Step boxes -->
<rect x="280" y="48" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
<text x="330" y="62" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">Q [S, D]</text>
<text x="330" y="74" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">query</text>

<rect x="280" y="90" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
<text x="330" y="104" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">K [S, D]</text>
<text x="330" y="116" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">key</text>

<rect x="280" y="132" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
<text x="330" y="146" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">V [S, D]</text>
<text x="330" y="158" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">value</text>

<!-- Arrow from Q and K to scores -->
<line x1="380" y1="64" x2="412" y2="90" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
<line x1="380" y1="106" x2="412" y2="96" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>

<rect x="414" y="78" width="110" height="34" rx="4" fill="#1a2a3c" stroke="#7ec8a0" stroke-width="1.2"/>
<text x="469" y="92" text-anchor="middle" fill="#7ec8a0" font-size="10" font-family="monospace">QK&#7488; / &#8730;D</text>
<text x="469" y="105" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">attn scores [S,S]</text>

<!-- Arrow: multiply by decay mask -->
<line x1="469" y1="112" x2="469" y2="128" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
<text x="505" y="124" fill="#cc88ff" font-size="9" font-family="monospace">&#8857; decay mask</text>

<rect x="414" y="130" width="110" height="34" rx="4" fill="#2a1a3c" stroke="#cc88ff" stroke-width="1.2"/>
<text x="469" y="144" text-anchor="middle" fill="#cc88ff" font-size="10" font-family="monospace">weighted [S,S]</text>
<text x="469" y="157" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">lower triangular</text>

<!-- Arrow from V and weighted to output -->
<line x1="380" y1="148" x2="412" y2="148" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
<line x1="524" y1="147" x2="546" y2="147" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/>
<text x="535" y="140" fill="#888" font-size="9" font-family="monospace">@</text>

<rect x="548" y="131" width="110" height="34" rx="4" fill="#1a3a1c" stroke="#4aff88" stroke-width="1.2"/>
<text x="603" y="145" text-anchor="middle" fill="#4aff88" font-size="10" font-family="monospace">output [S, D]</text>
<text x="603" y="158" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">weighted @ V</text>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the <code>solve</code> function; do not change its signature.</li>
<li>Do not use external libraries beyond those provided.</li>
<li>Write the result into <code>output</code>.</li>
</ul>

<h2>Example</h2>
<p>Example 1 — with <code>seq_len</code> = 2, <code>d_model</code> = 4, <code>gamma</code> = 0.5:</p>
<p>
\[
Q = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 1 & 1 & 0 & 0 \end{bmatrix}, \quad
K = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}, \quad
V = \begin{bmatrix} 4 & 8 & 12 & 16 \\ 4 & 8 & 12 & 16 \end{bmatrix}
\]
</p>
<p>
Attention scores \(QK^\top / \sqrt{4}\):
\[
A = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix}
\]
Causal decay mask \(D_{nm} = 0.5^{n-m}\) for \(n \ge m\), else \(0\):
\[
D = \begin{bmatrix} 1 & 0 \\ 0.5 & 1 \end{bmatrix}
\]
Weighted attention \(A \odot D\):
\[
\begin{bmatrix} 0.5 & 0 \\ 0.25 & 0.5 \end{bmatrix}
\]
Output \((A \odot D)\,V\):
\[
\text{output} = \begin{bmatrix} 2 & 4 & 6 & 8 \\ 3 & 6 & 9 & 12 \end{bmatrix}
\]
</p>
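<p>
The worked example can be replayed end to end in a few lines (a NumPy sketch; not part of the
challenge files):
</p>

```python
import numpy as np

Q = np.array([[1., 1., 0., 0.], [1., 1., 0., 0.]])
K = np.array([[1., 0., 0., 0.], [0., 1., 0., 0.]])
V = np.array([[4., 8., 12., 16.], [4., 8., 12., 16.]])
gamma, d_model = 0.5, 4

A = (Q @ K.T) / np.sqrt(d_model)          # [[0.5, 0.5], [0.5, 0.5]]
D = np.array([[1.0, 0.0], [gamma, 1.0]])  # gamma**(n - m) on and below the diagonal
output = (A * D) @ V                      # [[2, 4, 6, 8], [3, 6, 9, 12]]
```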

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>seq_len</code> &le; 8,192</li>
<li>1 &le; <code>d_model</code> &le; 256</li>
<li>0 &lt; <code>gamma</code> &le; 1</li>
<li>All tensors are <code>float32</code> on GPU.</li>
<li>Performance is measured with <code>seq_len</code> = 4,096 and <code>d_model</code> = 64.</li>
</ul>
146 changes: 146 additions & 0 deletions challenges/medium/92_decaying_causal_attention/challenge.py
@@ -0,0 +1,146 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="Decaying Causal Attention",
atol=1e-03,
rtol=1e-03,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
seq_len: int,
d_model: int,
gamma: float,
):
assert Q.shape == (seq_len, d_model)
assert K.shape == (seq_len, d_model)
assert V.shape == (seq_len, d_model)
assert output.shape == (seq_len, d_model)
assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
assert Q.device.type == "cuda"
assert K.device.type == "cuda"
assert V.device.type == "cuda"
assert output.device.type == "cuda"

scale = math.sqrt(d_model)
positions = torch.arange(seq_len, device=Q.device, dtype=Q.dtype)
# distances[n, m] = n - m; negative means m is in the future relative to n
distances = positions.unsqueeze(1) - positions.unsqueeze(0)
# causal: zero out future positions; clamping to >= 0 keeps gamma ** (negative
# distance) from overflowing to inf, where inf * 0 would produce NaN
causal = (distances >= 0).to(Q.dtype)
decay_mask = torch.pow(gamma, distances.clamp(min=0)) * causal
attn = torch.matmul(Q, K.T) / scale
output.copy_(torch.matmul(attn * decay_mask, V))

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
"K": (ctypes.POINTER(ctypes.c_float), "in"),
"V": (ctypes.POINTER(ctypes.c_float), "in"),
"output": (ctypes.POINTER(ctypes.c_float), "out"),
"seq_len": (ctypes.c_int, "in"),
"d_model": (ctypes.c_int, "in"),
"gamma": (ctypes.c_float, "in"),
}

def generate_example_test(self) -> Dict[str, Any]:
dtype = torch.float32
device = "cuda"
# Orthogonal K rows → QK^T / sqrt(4) = [[0.5, 0.5], [0.5, 0.5]].
# With gamma=0.5 decay mask [[1, 0], [0.5, 1]], weighted attn = [[0.5, 0], [0.25, 0.5]].
# Output row 0 = 0.5 * V[0] = [2, 4, 6, 8]; row 1 = 0.25 * V[0] + 0.5 * V[1] = [3, 6, 9, 12].
Q = torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype)
K = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype)
V = torch.tensor(
[[4.0, 8.0, 12.0, 16.0], [4.0, 8.0, 12.0, 16.0]], device=device, dtype=dtype
)
output = torch.zeros(2, 4, device=device, dtype=dtype)
return {"Q": Q, "K": K, "V": V, "output": output, "seq_len": 2, "d_model": 4, "gamma": 0.5}

def _make_test_case(
self,
seq_len: int,
d_model: int,
gamma: float = 0.9,
zero_qk: bool = False,
negative: bool = False,
) -> Dict[str, Any]:
dtype = torch.float32
device = "cuda"
if zero_qk:
Q = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
K = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
V = torch.randn(seq_len, d_model, device=device, dtype=dtype)
elif negative:
Q = torch.randn(seq_len, d_model, device=device, dtype=dtype).neg()
K = torch.randn(seq_len, d_model, device=device, dtype=dtype).neg()
V = torch.randn(seq_len, d_model, device=device, dtype=dtype).neg()
else:
Q = torch.randn(seq_len, d_model, device=device, dtype=dtype)
K = torch.randn(seq_len, d_model, device=device, dtype=dtype)
V = torch.randn(seq_len, d_model, device=device, dtype=dtype)
output = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"seq_len": seq_len,
"d_model": d_model,
"gamma": gamma,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
torch.manual_seed(42)
tests = []

# Edge: single token (only self-attention possible)
tests.append(self._make_test_case(1, 4, gamma=0.9))

# Edge: two tokens (matches example structure)
tests.append(self._make_test_case(2, 4, gamma=0.5))

# Edge: gamma=1.0 — no decay, equal weight to all past positions
tests.append(self._make_test_case(4, 8, gamma=1.0))

# Edge: small gamma — very sharp recency bias
tests.append(self._make_test_case(4, 8, gamma=0.1))

# Zero Q and K: all attention scores are zero → output must be all zeros
tests.append(self._make_test_case(8, 16, gamma=0.9, zero_qk=True))

# All-negative Q, K, V
tests.append(self._make_test_case(16, 16, gamma=0.8, negative=True))

# Power-of-2 sequence length
tests.append(self._make_test_case(32, 32, gamma=0.9))

# Power-of-2, larger
tests.append(self._make_test_case(64, 64, gamma=0.8))

# Non-power-of-2 sequence length
tests.append(self._make_test_case(30, 32, gamma=0.95))

# Non-power-of-2, larger realistic size
tests.append(self._make_test_case(100, 64, gamma=0.9))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
torch.manual_seed(0)
# Typical LLM head: seq_len=4096, head_dim=64
return self._make_test_case(4096, 64, gamma=0.9)
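
One detail in reference_impl worth noting: the clamp(min=0) before the power is not
cosmetic. For gamma < 1, gamma raised to a negative distance overflows float32 to inf at
realistic sequence lengths, and inf * 0 from the causal mask yields NaN instead of 0 in the
masked upper triangle. A NumPy demonstration of the failure mode (illustrative; not part of
the challenge files):

```python
import numpy as np

# Without the clamp, gamma ** (negative distance) overflows float32 to inf,
# and multiplying by the 0 from the causal mask gives NaN, not 0.
with np.errstate(over="ignore"):
    unclamped = np.float32(0.5) ** np.float32(-200.0)  # 2**200 exceeds float32 max (~3.4e38)
masked = unclamped * np.float32(0.0)                   # inf * 0 -> nan
```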
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Q, K, V, output are device pointers
extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int seq_len,
int d_model, float gamma) {}
@@ -0,0 +1,16 @@
import cutlass
import cutlass.cute as cute


# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
Q: cute.Tensor,
K: cute.Tensor,
V: cute.Tensor,
output: cute.Tensor,
seq_len: cute.Int32,
d_model: cute.Int32,
gamma: cute.Float32,
):
pass
@@ -0,0 +1,16 @@
import jax
import jax.numpy as jnp


# Q, K, V are tensors on GPU
@jax.jit
def solve(
Q: jax.Array,
K: jax.Array,
V: jax.Array,
seq_len: int,
d_model: int,
gamma: float,
) -> jax.Array:
# return output tensor directly
pass
@@ -0,0 +1,16 @@
from std.gpu.host import DeviceContext
from std.memory import UnsafePointer


# Q, K, V, output are device pointers
@export
def solve(
Q: UnsafePointer[Float32, MutExternalOrigin],
K: UnsafePointer[Float32, MutExternalOrigin],
V: UnsafePointer[Float32, MutExternalOrigin],
output: UnsafePointer[Float32, MutExternalOrigin],
seq_len: Int32,
d_model: Int32,
gamma: Float32,
) raises:
pass
@@ -0,0 +1,14 @@
import torch


# Q, K, V, output are tensors on the GPU
def solve(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
seq_len: int,
d_model: int,
gamma: float,
):
pass