134 changes: 134 additions & 0 deletions challenges/medium/94_ssm_selective_scan/challenge.html
@@ -0,0 +1,134 @@
<p>
Implement the forward pass of a State Space Model (SSM) selective scan, the core operation in
Mamba-style sequence models. Given an input sequence <code>u</code>, time-step parameters
<code>delta</code>, state-transition matrix <code>A</code>, input projection <code>B</code>,
output projection <code>C</code>, and skip-connection weights <code>skip</code>, compute the
output sequence <code>y</code> in float32.
</p>

<svg width="700" height="180" viewBox="0 0 700 180" style="display:block; margin:20px auto;" xmlns="http://www.w3.org/2000/svg">
<rect width="700" height="180" fill="#222" rx="10"/>
<!-- SSM chain diagram -->
<!-- State boxes -->
<rect x="55" y="70" width="60" height="40" rx="6" fill="#1a3a5c" stroke="#4a90d9" stroke-width="1.5"/>
<text x="85" y="95" fill="#4a90d9" font-family="monospace" font-size="13" text-anchor="middle">h₀</text>

<rect x="195" y="70" width="60" height="40" rx="6" fill="#1a3a5c" stroke="#4a90d9" stroke-width="1.5"/>
<text x="225" y="95" fill="#4a90d9" font-family="monospace" font-size="13" text-anchor="middle">h₁</text>

<rect x="335" y="70" width="60" height="40" rx="6" fill="#1a3a5c" stroke="#4a90d9" stroke-width="1.5"/>
<text x="365" y="95" fill="#4a90d9" font-family="monospace" font-size="13" text-anchor="middle">h₂</text>

<rect x="475" y="70" width="60" height="40" rx="6" fill="#1a3a5c" stroke="#4a90d9" stroke-width="1.5"/>
<text x="505" y="95" fill="#4a90d9" font-family="monospace" font-size="13" text-anchor="middle">h₃</text>

<!-- Recurrence arrows -->
<line x1="115" y1="90" x2="193" y2="90" stroke="#4a90d9" stroke-width="1.5" marker-end="url(#arr)"/>
<line x1="255" y1="90" x2="333" y2="90" stroke="#4a90d9" stroke-width="1.5" marker-end="url(#arr)"/>
<line x1="395" y1="90" x2="473" y2="90" stroke="#4a90d9" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="153" y="83" fill="#ccc" font-family="monospace" font-size="10" text-anchor="middle">Ā</text>
<text x="293" y="83" fill="#ccc" font-family="monospace" font-size="10" text-anchor="middle">Ā</text>
<text x="433" y="83" fill="#ccc" font-family="monospace" font-size="10" text-anchor="middle">Ā</text>

<!-- Input arrows (u into h) -->
<line x1="85" y1="155" x2="85" y2="112" stroke="#5cb85c" stroke-width="1.5" marker-end="url(#garr)"/>
<line x1="225" y1="155" x2="225" y2="112" stroke="#5cb85c" stroke-width="1.5" marker-end="url(#garr)"/>
<line x1="365" y1="155" x2="365" y2="112" stroke="#5cb85c" stroke-width="1.5" marker-end="url(#garr)"/>
<line x1="505" y1="155" x2="505" y2="112" stroke="#5cb85c" stroke-width="1.5" marker-end="url(#garr)"/>
<text x="85" y="168" fill="#5cb85c" font-family="monospace" font-size="11" text-anchor="middle">B̄u₀</text>
<text x="225" y="168" fill="#5cb85c" font-family="monospace" font-size="11" text-anchor="middle">B̄u₁</text>
<text x="365" y="168" fill="#5cb85c" font-family="monospace" font-size="11" text-anchor="middle">B̄u₂</text>
<text x="505" y="168" fill="#5cb85c" font-family="monospace" font-size="11" text-anchor="middle">B̄u₃</text>

<!-- Output arrows (h to y) -->
<line x1="85" y1="68" x2="85" y2="30" stroke="#e87c2e" stroke-width="1.5" marker-end="url(#oarr)"/>
<line x1="225" y1="68" x2="225" y2="30" stroke="#e87c2e" stroke-width="1.5" marker-end="url(#oarr)"/>
<line x1="365" y1="68" x2="365" y2="30" stroke="#e87c2e" stroke-width="1.5" marker-end="url(#oarr)"/>
<line x1="505" y1="68" x2="505" y2="30" stroke="#e87c2e" stroke-width="1.5" marker-end="url(#oarr)"/>
<text x="85" y="22" fill="#e87c2e" font-family="monospace" font-size="11" text-anchor="middle">y₀</text>
<text x="225" y="22" fill="#e87c2e" font-family="monospace" font-size="11" text-anchor="middle">y₁</text>
<text x="365" y="22" fill="#e87c2e" font-family="monospace" font-size="11" text-anchor="middle">y₂</text>
<text x="505" y="22" fill="#e87c2e" font-family="monospace" font-size="11" text-anchor="middle">y₃</text>

<!-- Continuation arrow -->
<line x1="535" y1="90" x2="590" y2="90" stroke="#4a90d9" stroke-width="1.5" stroke-dasharray="4,3" marker-end="url(#arr)"/>

<!-- Arrow markers -->
<defs>
<marker id="arr" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="#4a90d9"/>
</marker>
<marker id="garr" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="#5cb85c"/>
</marker>
<marker id="oarr" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="#e87c2e"/>
</marker>
</defs>
</svg>

<h2>Implementation Requirements</h2>
<p>
Implement the function <code>solve(u, delta, A, B, C, skip, y, batch, seq_len, d_model, d_state)</code>
with the signature unchanged. Do not use external libraries beyond the allowed framework.
Write the result into the pre-allocated output tensor <code>y</code>.
</p>
<p>
For each batch <code>b</code>, position <code>t</code>, channel <code>d</code>, and state index <code>n</code>, the computation is:
</p>
<p>
\[
\bar{A}_{b,t,d,n} = \exp(\Delta_{b,t,d} \cdot A_{d,n})
\]
\[
\bar{B}_{b,t,d,n} = \Delta_{b,t,d} \cdot B_{b,t,n}
\]
\[
h_{b,t,d,n} = \bar{A}_{b,t,d,n} \cdot h_{b,t-1,d,n} + \bar{B}_{b,t,d,n} \cdot u_{b,t,d}
\]
\[
y_{b,t,d} = \sum_{n} C_{b,t,n} \cdot h_{b,t,d,n} + \text{skip}_d \cdot u_{b,t,d}
\]
</p>
<p>
The initial hidden state \(h_{b,-1,d,n} = 0\) for all \(b, d, n\).
All channels <code>d</code> are independent: they share the same <code>B</code> and <code>C</code>
projections but have separate state-transition rows in <code>A</code>.
</p>

<h2>Example</h2>
<pre>
Input:
u = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]] shape (1,4,2)
delta = [[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]] shape (1,4,2)
A = [[-0.5, -1.0], [-0.5, -1.0]] shape (2,2)
B = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]] shape (1,4,2)
C = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]] shape (1,4,2)
skip = [0.0, 0.0] shape (2,)
batch=1, seq_len=4, d_model=2, d_state=2

Derivation (delta=1 everywhere, so A_bar_dn = exp(A_dn)):
A_bar[d=0] = [exp(-0.5), exp(-1.0)] ≈ [0.607, 0.368]
A_bar[d=1] = [exp(-0.5), exp(-1.0)] ≈ [0.607, 0.368]

Hidden state h has shape (d_model=2, d_state=2); initial h = zeros.
t=0: h = [[1.000, 0.000], [0.000, 0.000]] → y[0,0] = [1.000, 0.000]
t=1: h = [[0.607, 0.000], [0.000, 1.000]] → y[0,1] = [0.000, 1.000]
t=2: h = [[1.368, 1.000], [1.000, 1.368]] → y[0,2] = [2.368, 2.368]
t=3: h = [[0.830, 0.368], [0.607, 0.503]] → y[0,3] = [0.599, 0.555]

Output:
y = [[[1.000, 0.000], [0.000, 1.000], [2.368, 2.368], [0.599, 0.555]]]
</pre>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>batch</code> &le; 16</li>
<li>1 &le; <code>seq_len</code> &le; 8,192</li>
<li>1 &le; <code>d_model</code> &le; 2,048</li>
<li>1 &le; <code>d_state</code> &le; 64</li>
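<li>All entries of <code>delta</code> are non-negative (the functional tests include an all-zero <code>delta</code> case)</li>
<li>All entries of <code>A</code> are negative (ensuring <code>A_bar &isin; (0, 1]</code>, with 1 reached only when <code>delta</code> = 0)</li>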
<li>All tensors are float32 on the GPU</li>
<li>Performance is measured with <code>batch</code> = 4, <code>seq_len</code> = 4,096, <code>d_model</code> = 512, <code>d_state</code> = 16</li>
</ul>
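The recurrence in challenge.html can be checked by hand against the worked example. The following is a minimal pure-Python sketch, not the reference implementation from this PR: the helper name `selective_scan` and the nested-list layout (batch dimension dropped) are assumptions made for illustration.

```python
import math


def selective_scan(u, delta, A, B, C, skip):
    """Sequential scan for one batch element, using nested lists.

    u, delta: [seq_len][d_model]; A: [d_model][d_state];
    B, C: [seq_len][d_state]; skip: [d_model]. Returns y: [seq_len][d_model].
    """
    seq_len, d_model, d_state = len(u), len(A), len(A[0])
    h = [[0.0] * d_state for _ in range(d_model)]  # h_{-1} = 0
    y = []
    for t in range(seq_len):
        y_t = []
        for d in range(d_model):
            acc = skip[d] * u[t][d]  # skip-connection term
            for n in range(d_state):
                a_bar = math.exp(delta[t][d] * A[d][n])  # discretized A
                b_bar = delta[t][d] * B[t][n]  # discretized B
                h[d][n] = a_bar * h[d][n] + b_bar * u[t][d]
                acc += C[t][n] * h[d][n]
            y_t.append(acc)
        y.append(y_t)
    return y


# Inputs from the Example section (batch dimension dropped).
u = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
delta = [[1.0, 1.0] for _ in range(4)]
A = [[-0.5, -1.0], [-0.5, -1.0]]
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
C = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
skip = [0.0, 0.0]

y = selective_scan(u, delta, A, B, C, skip)
print([[round(v, 3) for v in row] for row in y])
# [[1.0, 0.0], [0.0, 1.0], [2.368, 2.368], [0.599, 0.555]]
```

The rounded values agree with the Output block of the example, which suggests the derivation table in challenge.html is consistent with the four equations.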
201 changes: 201 additions & 0 deletions challenges/medium/94_ssm_selective_scan/challenge.py
@@ -0,0 +1,201 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="SSM Selective Scan",
            atol=1e-03,
            rtol=1e-03,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        u: torch.Tensor,
        delta: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        skip: torch.Tensor,
        y: torch.Tensor,
        batch: int,
        seq_len: int,
        d_model: int,
        d_state: int,
    ):
        assert u.shape == (batch, seq_len, d_model)
        assert delta.shape == (batch, seq_len, d_model)
        assert A.shape == (d_model, d_state)
        assert B.shape == (batch, seq_len, d_state)
        assert C.shape == (batch, seq_len, d_state)
        assert skip.shape == (d_model,)
        assert y.shape == (batch, seq_len, d_model)
        assert (
            u.dtype == delta.dtype == A.dtype == B.dtype == C.dtype == skip.dtype == torch.float32
        )
        assert u.device.type == "cuda"
        assert delta.device.type == "cuda"
        assert A.device.type == "cuda"
        assert B.device.type == "cuda"
        assert C.device.type == "cuda"
        assert skip.device.type == "cuda"
        assert y.device.type == "cuda"

        # Hidden state: (batch, d_model, d_state)
        h = torch.zeros(batch, d_model, d_state, device=u.device, dtype=u.dtype)

        for t in range(seq_len):
            delta_t = delta[:, t, :]  # (batch, d_model)
            u_t = u[:, t, :]  # (batch, d_model)

            # Discretize: A_bar = exp(delta_t * A)
            # delta_t: (batch, d_model) -> (batch, d_model, 1)
            # A: (d_model, d_state) -> (1, d_model, d_state)
            A_bar = torch.exp(delta_t.unsqueeze(-1) * A.unsqueeze(0))  # (batch, d_model, d_state)

            # B_bar = delta_t * B_t
            # B[:, t, :]: (batch, d_state) -> (batch, 1, d_state)
            B_bar = delta_t.unsqueeze(-1) * B[:, t, :].unsqueeze(1)  # (batch, d_model, d_state)

            # State update: h = A_bar * h + B_bar * u_t
            h = A_bar * h + B_bar * u_t.unsqueeze(-1)  # (batch, d_model, d_state)

            # Output: y_t = C_t @ h + skip * u_t
            # C[:, t, :]: (batch, d_state) -> einsum with h (batch, d_model, d_state)
            C_t = C[:, t, :]  # (batch, d_state)
            y_t = torch.einsum("bn,bdn->bd", C_t, h) + skip * u_t  # (batch, d_model)
            y[:, t, :] = y_t

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "u": (ctypes.POINTER(ctypes.c_float), "in"),
            "delta": (ctypes.POINTER(ctypes.c_float), "in"),
            "A": (ctypes.POINTER(ctypes.c_float), "in"),
            "B": (ctypes.POINTER(ctypes.c_float), "in"),
            "C": (ctypes.POINTER(ctypes.c_float), "in"),
            "skip": (ctypes.POINTER(ctypes.c_float), "in"),
            "y": (ctypes.POINTER(ctypes.c_float), "out"),
            "batch": (ctypes.c_int, "in"),
            "seq_len": (ctypes.c_int, "in"),
            "d_model": (ctypes.c_int, "in"),
            "d_state": (ctypes.c_int, "in"),
        }

    def _make_test_case(self, batch, seq_len, d_model, d_state, zero_u=False, zero_delta=False):
        device = "cuda"
        dtype = torch.float32
        if zero_u:
            u = torch.zeros(batch, seq_len, d_model, device=device, dtype=dtype)
        else:
            u = torch.randn(batch, seq_len, d_model, device=device, dtype=dtype)
        if zero_delta:
            delta = torch.zeros(batch, seq_len, d_model, device=device, dtype=dtype)
        else:
            # delta must be positive
            delta = torch.rand(batch, seq_len, d_model, device=device, dtype=dtype) + 0.01
        # A must be negative for stability (eigenvalues < 0)
        A = -torch.rand(d_model, d_state, device=device, dtype=dtype) - 0.01
        B = torch.randn(batch, seq_len, d_state, device=device, dtype=dtype)
        C = torch.randn(batch, seq_len, d_state, device=device, dtype=dtype)
        skip = torch.rand(d_model, device=device, dtype=dtype)
        y = torch.empty(batch, seq_len, d_model, device=device, dtype=dtype)
        return {
            "u": u,
            "delta": delta,
            "A": A,
            "B": B,
            "C": C,
            "skip": skip,
            "y": y,
            "batch": batch,
            "seq_len": seq_len,
            "d_model": d_model,
            "d_state": d_state,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        device = "cuda"
        dtype = torch.float32
        batch, seq_len, d_model, d_state = 1, 4, 2, 2
        u = torch.tensor(
            [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]],
            device=device,
            dtype=dtype,
        )
        delta = torch.tensor(
            [[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]],
            device=device,
            dtype=dtype,
        )
        A = torch.tensor([[-0.5, -1.0], [-0.5, -1.0]], device=device, dtype=dtype)
        B = torch.tensor(
            [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]],
            device=device,
            dtype=dtype,
        )
        C = torch.tensor(
            [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]],
            device=device,
            dtype=dtype,
        )
        skip = torch.tensor([0.0, 0.0], device=device, dtype=dtype)
        y = torch.empty(batch, seq_len, d_model, device=device, dtype=dtype)
        return {
            "u": u,
            "delta": delta,
            "A": A,
            "B": B,
            "C": C,
            "skip": skip,
            "y": y,
            "batch": batch,
            "seq_len": seq_len,
            "d_model": d_model,
            "d_state": d_state,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge case: single token
        tests.append(self._make_test_case(1, 1, 1, 4))

        # Edge case: tiny dimensions
        tests.append(self._make_test_case(1, 2, 2, 2))

        # Edge case: zero input (output should be skip * 0 = 0)
        tests.append(self._make_test_case(1, 4, 4, 4, zero_u=True))

        # Edge case: zero delta (A_bar=1, B_bar=0, so state stays zero, output = skip * u)
        tests.append(self._make_test_case(2, 4, 4, 4, zero_delta=True))

        # Power-of-2 lengths
        tests.append(self._make_test_case(2, 16, 8, 4))
        tests.append(self._make_test_case(2, 64, 16, 8))

        # Non-power-of-2
        tests.append(self._make_test_case(2, 30, 12, 4))
        tests.append(self._make_test_case(3, 100, 24, 8))

        # Typical d_state=16 (common Mamba setting)
        tests.append(self._make_test_case(2, 128, 32, 16))

        # Realistic size
        tests.append(self._make_test_case(4, 256, 64, 16))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # batch=4, seq_len=4096, d_model=512, d_state=16
        # Memory: u+delta+y ~ 3 * 4*4096*512*4 = 96MB; A+B+C+skip small
        # Total << 1GB, comfortably fits 5x in 16GB T4
        return self._make_test_case(4, 4096, 512, 16)
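The two edge cases in `generate_functional_test` (`zero_u` and `zero_delta`) follow directly from the recurrence: with `u = 0` every term of the state update and the skip term vanish, and with `delta = 0` we get `A_bar = 1` and `B_bar = 0`, so the state never leaves zero. The sketch below illustrates this on the CPU with plain Python lists; it is not part of the PR, and the helper `selective_scan` and the sizes `T, D, N` are hypothetical names chosen for the illustration.

```python
import math
import random


def selective_scan(u, delta, A, B, C, skip):
    """Sequential scan for one batch element (nested lists, no batch axis)."""
    d_model, d_state = len(A), len(A[0])
    h = [[0.0] * d_state for _ in range(d_model)]
    y = []
    for t in range(len(u)):
        y_t = []
        for d in range(d_model):
            acc = skip[d] * u[t][d]
            for n in range(d_state):
                a_bar = math.exp(delta[t][d] * A[d][n])
                b_bar = delta[t][d] * B[t][n]
                h[d][n] = a_bar * h[d][n] + b_bar * u[t][d]
                acc += C[t][n] * h[d][n]
            y_t.append(acc)
        y.append(y_t)
    return y


random.seed(0)
T, D, N = 5, 3, 4  # hypothetical small sizes: seq_len, d_model, d_state
A = [[-random.random() - 0.01 for _ in range(N)] for _ in range(D)]
B = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
C = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
skip = [random.random() for _ in range(D)]
u = [[random.gauss(0, 1) for _ in range(D)] for _ in range(T)]
delta = [[random.random() + 0.01 for _ in range(D)] for _ in range(T)]
zeros = [[0.0] * D for _ in range(T)]

# zero_u: every output should be exactly zero.
y_zero_u = selective_scan(zeros, delta, A, B, C, skip)
# zero_delta: the state stays zero, so y reduces to skip * u.
y_zero_delta = selective_scan(u, zeros, A, B, C, skip)
skip_only = [[skip[d] * u[t][d] for d in range(D)] for t in range(T)]

print(all(v == 0.0 for row in y_zero_u for v in row))
print(y_zero_delta == skip_only)
```

Both checks hold exactly here because the state is never perturbed away from zero; a GPU kernel compared under `atol`/`rtol` does not need bitwise equality.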
7 changes: 7 additions & 0 deletions challenges/medium/94_ssm_selective_scan/starter/starter.cu
@@ -0,0 +1,7 @@
#include <cuda_runtime.h>
#include <math.h>

// u, delta, A, B, C, skip, y are device pointers
extern "C" void solve(const float* u, const float* delta, const float* A, const float* B,
                      const float* C, const float* skip, float* y, int batch, int seq_len,
                      int d_model, int d_state) {}
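The starter leaves the kernel body empty, and the PR does not prescribe a parallelization strategy. One possible approach, offered here only as an assumption, is to treat each step as the affine map `h -> a_t * h + b_t` with `a_t = A_bar` and `b_t = B_bar * u_t`; composing two such maps is associative, so the sequence dimension can be handled with a parallel prefix scan. The Python sketch below (helper names `combine`, `sequential`, and `inclusive_scan` are hypothetical) checks that a log-step scan reproduces the sequential recurrence for a single `(b, d, n)` lane.

```python
import math
import random


def combine(first, second):
    """Compose h -> a1*h + b1 followed by h -> a2*h + b2."""
    a1, b1 = first
    a2, b2 = second
    return (a1 * a2, a2 * b1 + b2)


def sequential(steps):
    """Plain recurrence h_t = a_t * h_{t-1} + b_t with h_{-1} = 0."""
    h = 0.0
    out = []
    for a, b in steps:
        h = a * h + b
        out.append(h)
    return out


def inclusive_scan(steps):
    """Hillis-Steele style inclusive scan; each round doubles the span."""
    acc = list(steps)
    offset = 1
    while offset < len(acc):
        nxt = list(acc)
        for i in range(offset, len(acc)):
            # Earlier segment first, later segment second.
            nxt[i] = combine(acc[i - offset], acc[i])
        acc = nxt
        offset *= 2
    # With h_{-1} = 0, h_t equals the offset part of the composed map.
    return [b for _, b in acc]


random.seed(0)
A_dn = -0.7  # one negative entry of A, chosen arbitrarily
steps = []
for _ in range(11):  # non-power-of-2 length
    delta_t = random.random() + 0.01
    B_tn = random.gauss(0, 1)
    u_td = random.gauss(0, 1)
    steps.append((math.exp(delta_t * A_dn), delta_t * B_tn * u_td))

max_err = max(abs(p - s) for p, s in zip(inclusive_scan(steps), sequential(steps)))
print(max_err < 1e-12)
```

On a GPU, the inner `for i` loop of each round would map to threads; whether a scan beats a per-channel sequential loop at `seq_len` = 4,096 depends on the kernel design and is not claimed here.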
20 changes: 20 additions & 0 deletions challenges/medium/94_ssm_selective_scan/starter/starter.cute.py
@@ -0,0 +1,20 @@
import cutlass
import cutlass.cute as cute


# u, delta, A, B, C, skip, y are tensors on the GPU
@cute.jit
def solve(
    u: cute.Tensor,
    delta: cute.Tensor,
    A: cute.Tensor,
    B: cute.Tensor,
    C: cute.Tensor,
    skip: cute.Tensor,
    y: cute.Tensor,
    batch: cute.Uint32,
    seq_len: cute.Uint32,
    d_model: cute.Uint32,
    d_state: cute.Uint32,
):
    pass
20 changes: 20 additions & 0 deletions challenges/medium/94_ssm_selective_scan/starter/starter.jax.py
@@ -0,0 +1,20 @@
import jax
import jax.numpy as jnp


# u, delta, A, B, C, skip are tensors on GPU
@jax.jit
def solve(
    u: jax.Array,
    delta: jax.Array,
    A: jax.Array,
    B: jax.Array,
    C: jax.Array,
    skip: jax.Array,
    batch: int,
    seq_len: int,
    d_model: int,
    d_state: int,
) -> jax.Array:
    # return output tensor directly
    pass