116 changes: 116 additions & 0 deletions challenges/medium/91_moe_token_dispatch/challenge.html
@@ -0,0 +1,116 @@
<p>
In Mixture-of-Experts (MoE) neural networks, a router assigns each input token to one of
<code>E</code> expert networks. Before those experts can execute in parallel, tokens must be
<em>dispatched</em>: reorganized from their original sequence order into contiguous per-expert
buffers. Given <code>T</code> tokens of feature dimension <code>D</code> and their expert
assignments, fill each expert's slice of the output buffer with the tokens routed to that
expert, preserving the original token order within each expert's group.
</p>
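The dispatch described above can be sketched in a few lines of plain Python (illustrative only; `dispatch` is a hypothetical helper, and real solutions run on the GPU):

```python
def dispatch(x, expert_idx, E, capacity):
    # Zero-initialized per-expert buffers: [E, capacity, D]
    D = len(x[0])
    dispatched = [[[0.0] * D for _ in range(capacity)] for _ in range(E)]
    counts = [0] * E
    # Walking tokens in original order keeps each expert's group stable.
    for t, e in enumerate(expert_idx):
        dispatched[e][counts[e]] = list(x[t])
        counts[e] += 1
    return dispatched, counts

out, counts = dispatch(
    [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0]],
    [0, 1, 0, 1], E=2, capacity=4,
)
# counts == [2, 2]; out[0] starts with tokens 0 and 2, out[1] with tokens 1 and 3
```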

<svg width="540" height="230" viewBox="0 0 540 230" style="display:block; margin:20px auto;" xmlns="http://www.w3.org/2000/svg">
<rect width="540" height="230" rx="8" fill="#222"/>

<!-- Title -->
<text x="270" y="22" text-anchor="middle" fill="#ccc" font-family="monospace" font-size="12">MoE Token Dispatch (T=4, E=2)</text>

<!-- Left: input tokens -->
<text x="80" y="48" text-anchor="middle" fill="#aaa" font-family="monospace" font-size="11">x [T, D]</text>

<!-- Token 0 → expert 0 (blue) -->
<rect x="20" y="55" width="120" height="28" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
<text x="80" y="73" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="11">token 0 → expert 0</text>

<!-- Token 1 → expert 1 (green) -->
<rect x="20" y="90" width="120" height="28" rx="4" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.5"/>
<text x="80" y="108" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="11">token 1 → expert 1</text>

<!-- Token 2 → expert 0 (blue) -->
<rect x="20" y="125" width="120" height="28" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
<text x="80" y="143" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="11">token 2 → expert 0</text>

<!-- Token 3 → expert 1 (green) -->
<rect x="20" y="160" width="120" height="28" rx="4" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.5"/>
<text x="80" y="178" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="11">token 3 → expert 1</text>

<!-- Arrows for expert 0 tokens -->
<line x1="140" y1="69" x2="195" y2="90" stroke="#4a9eff" stroke-width="1.5" marker-end="url(#arrowblue)"/>
<line x1="140" y1="139" x2="195" y2="110" stroke="#4a9eff" stroke-width="1.5" marker-end="url(#arrowblue)"/>

<!-- Arrows for expert 1 tokens -->
<line x1="140" y1="104" x2="195" y2="155" stroke="#4adf7f" stroke-width="1.5" marker-end="url(#arrowgreen)"/>
<line x1="140" y1="174" x2="195" y2="170" stroke="#4adf7f" stroke-width="1.5" marker-end="url(#arrowgreen)"/>

<!-- Arrow markers -->
<defs>
<marker id="arrowblue" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#4a9eff"/>
</marker>
<marker id="arrowgreen" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#4adf7f"/>
</marker>
</defs>

<!-- Right: expert 0 buffer -->
<text x="370" y="48" text-anchor="middle" fill="#aaa" font-family="monospace" font-size="11">dispatched_x [E, capacity, D]</text>

<text x="280" y="78" fill="#4a9eff" font-family="monospace" font-size="10">expert 0</text>
<rect x="200" y="83" width="160" height="24" rx="3" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
<text x="280" y="99" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="10">slot 0: token 0 features</text>
<rect x="200" y="108" width="160" height="24" rx="3" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
<text x="280" y="124" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="10">slot 1: token 2 features</text>

<!-- Right: expert 1 buffer -->
<text x="280" y="148" fill="#4adf7f" font-family="monospace" font-size="10">expert 1</text>
<rect x="200" y="153" width="160" height="24" rx="3" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.2"/>
<text x="280" y="169" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="10">slot 0: token 1 features</text>
<rect x="200" y="178" width="160" height="24" rx="3" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.2"/>
<text x="280" y="194" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="10">slot 1: token 3 features</text>

<!-- token_counts label -->
<text x="280" y="218" text-anchor="middle" fill="#888" font-family="monospace" font-size="10">token_counts = [2, 2]</text>
</svg>

<h2>Implementation Requirements</h2>
<p>
Your <code>solve</code> function signature must not be changed. No external libraries beyond
those already imported. Write outputs to <code>dispatched_x</code> and
<code>token_counts</code> in place.
</p>
<p>
For each expert <code>e</code>, <code>dispatched_x[e, :token_counts[e], :]</code> must
contain exactly the tokens assigned to expert <code>e</code>, in the same relative order
they appear in the original token sequence (i.e., token index ascending). Entries at
positions <code>token_counts[e]</code> and beyond may be left as zero.
</p>
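An equivalent way to state the ordering requirement: a stable sort of token indices by expert id yields exactly the concatenated per-expert groups, and each token's destination slot is its position within its group. A pure-Python sketch (`destination_slots` is a hypothetical helper):

```python
def destination_slots(expert_idx):
    # Python's sort is stable, so tokens of the same expert keep
    # their original relative order, matching the required layout.
    order = sorted(range(len(expert_idx)), key=lambda t: expert_idx[t])
    slots = [0] * len(expert_idx)
    seen = {}
    for t in order:
        e = expert_idx[t]
        slots[t] = seen.get(e, 0)  # slot within expert e's buffer
        seen[e] = slots[t] + 1
    return slots

# For the figure above: expert_idx = [0, 1, 0, 1]
# token 0 -> slot 0 of expert 0, token 2 -> slot 1 of expert 0, etc.
```

This formulation is handy because a GPU kernel can compute the same slots with a stable sort plus a per-expert prefix sum instead of a sequential scan.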

<h2>Example</h2>
<p>
Given <code>T</code> = 4 tokens of dimension <code>D</code> = 3 routed to
<code>E</code> = 2 experts (capacity = 4):
</p>

<p>Input tokens \(x\):</p>
<p>$$x = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix}$$</p>

<p>Expert assignments:</p>
<pre>expert_idx = [0, 1, 0, 1]</pre>

<p>Output token counts:</p>
<pre>token_counts = [2, 2]</pre>

<p>\(\texttt{dispatched\_x}[0]\) &mdash; expert 0 receives tokens 0 and 2 (in original order):</p>
<p>$$\begin{bmatrix} 1 & 0 & 0 \\ 0 & 0 & 1 \end{bmatrix}$$</p>

<p>\(\texttt{dispatched\_x}[1]\) &mdash; expert 1 receives tokens 1 and 3 (in original order):</p>
<p>$$\begin{bmatrix} 0 & 1 & 0 \\ 1 & 1 & 0 \end{bmatrix}$$</p>
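The example above can be checked mechanically in plain Python, grouping tokens by expert while preserving token order:

```python
x = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0]]
expert_idx = [0, 1, 0, 1]
E = 2

# For each expert, collect its tokens in ascending token index.
groups = [[x[t] for t in range(len(x)) if expert_idx[t] == e] for e in range(E)]
token_counts = [len(g) for g in groups]
# token_counts == [2, 2]
# groups[0] == [[1, 0, 0], [0, 0, 1]]   (tokens 0 and 2)
# groups[1] == [[0, 1, 0], [1, 1, 0]]   (tokens 1 and 3)
```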

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>T</code> &le; 65,536</li>
<li>1 &le; <code>D</code> &le; 2,048</li>
<li>2 &le; <code>E</code> &le; 64</li>
<li><code>capacity</code> = <code>T</code> (always sufficient for any routing assignment)</li>
<li>All values in <code>expert_idx</code> are in the range [0, <code>E</code> &minus; 1]</li>
<li><code>x</code> values are <code>float32</code>; <code>expert_idx</code> and <code>token_counts</code> are <code>int32</code></li>
<li>Performance is measured with <code>T</code> = 16,384, <code>D</code> = 512, <code>E</code> = 8</li>
</ul>
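For intuition about the performance configuration: dispatch is bandwidth-bound, since each token's features are read once and written once. A rough traffic estimate (an assumption; it ignores index traffic and any re-reads):

```python
T, D = 16_384, 512
bytes_per_float = 4  # float32

read_bytes = T * D * bytes_per_float    # read every token's features once
write_bytes = T * D * bytes_per_float   # write every token to its slot
total_mib = (read_bytes + write_bytes) / 2**20
# total_mib == 64.0 (MiB of feature traffic)
```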
200 changes: 200 additions & 0 deletions challenges/medium/91_moe_token_dispatch/challenge.py
@@ -0,0 +1,200 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="MoE Token Dispatch",
            atol=1e-05,
            rtol=1e-05,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        expert_idx: torch.Tensor,
        dispatched_x: torch.Tensor,
        token_counts: torch.Tensor,
        T: int,
        D: int,
        E: int,
        capacity: int,
    ):
        assert x.shape == (T, D)
        assert expert_idx.shape == (T,)
        assert dispatched_x.shape == (E, capacity, D)
        assert token_counts.shape == (E,)
        assert x.dtype == torch.float32
        assert expert_idx.dtype == torch.int32
        assert dispatched_x.dtype == torch.float32
        assert token_counts.dtype == torch.int32
        assert x.device.type == "cuda"
        assert expert_idx.device.type == "cuda"

        for e in range(E):
            # torch.where returns indices in ascending order — stable within each expert
            indices = torch.where(expert_idx == e)[0]
            count = int(indices.shape[0])
            assert count <= capacity, f"Expert {e} has {count} tokens but capacity is {capacity}"
            dispatched_x[e, :count] = x[indices]
            token_counts[e] = count

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "x": (ctypes.POINTER(ctypes.c_float), "in"),
            "expert_idx": (ctypes.POINTER(ctypes.c_int), "in"),
            "dispatched_x": (ctypes.POINTER(ctypes.c_float), "out"),
            "token_counts": (ctypes.POINTER(ctypes.c_int), "out"),
            "T": (ctypes.c_int, "in"),
            "D": (ctypes.c_int, "in"),
            "E": (ctypes.c_int, "in"),
            "capacity": (ctypes.c_int, "in"),
        }

    def _make_test(
        self,
        T: int,
        D: int,
        E: int,
        expert_idx_tensor: torch.Tensor = None,
        seed: int = 42,
    ) -> Dict[str, Any]:
        torch.manual_seed(seed)
        capacity = T
        x = torch.randn(T, D, device="cuda", dtype=torch.float32)
        if expert_idx_tensor is not None:
            expert_idx = expert_idx_tensor
        else:
            expert_idx = torch.randint(0, E, (T,), device="cuda", dtype=torch.int32)
        dispatched_x = torch.zeros(E, capacity, D, device="cuda", dtype=torch.float32)
        token_counts = torch.zeros(E, device="cuda", dtype=torch.int32)
        return {
            "x": x,
            "expert_idx": expert_idx,
            "dispatched_x": dispatched_x,
            "token_counts": token_counts,
            "T": T,
            "D": D,
            "E": E,
            "capacity": capacity,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        T, D, E = 4, 3, 2
        capacity = T
        x = torch.tensor(
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]],
            device="cuda",
            dtype=torch.float32,
        )
        expert_idx = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int32)
        dispatched_x = torch.zeros(E, capacity, D, device="cuda", dtype=torch.float32)
        token_counts = torch.zeros(E, device="cuda", dtype=torch.int32)
        return {
            "x": x,
            "expert_idx": expert_idx,
            "dispatched_x": dispatched_x,
            "token_counts": token_counts,
            "T": T,
            "D": D,
            "E": E,
            "capacity": capacity,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        tests = []

        # Edge case: single token goes to expert 0, expert 1 is empty
        tests.append(
            self._make_test(
                1,
                4,
                2,
                expert_idx_tensor=torch.tensor([0], device="cuda", dtype=torch.int32),
            )
        )

        # Edge case: two tokens, both assigned to expert 0
        tests.append(
            self._make_test(
                2,
                4,
                2,
                expert_idx_tensor=torch.tensor([0, 0], device="cuda", dtype=torch.int32),
            )
        )

        # Edge case: exactly one token per expert (T == E)
        tests.append(
            self._make_test(
                4,
                8,
                4,
                expert_idx_tensor=torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.int32),
                seed=1,
            )
        )

        # Skewed distribution: 6 of 8 tokens go to expert 0
        skewed = torch.tensor([0, 0, 0, 0, 0, 0, 1, 2], device="cuda", dtype=torch.int32)
        tests.append(self._make_test(8, 8, 4, expert_idx_tensor=skewed, seed=10))

        # Power-of-2: T=32, cycling uniformly through 4 experts
        uniform32 = (torch.arange(32, device="cuda") % 4).to(torch.int32)
        tests.append(self._make_test(32, 16, 4, expert_idx_tensor=uniform32, seed=2))

        # Power-of-2: T=256, random assignments to 8 experts
        tests.append(self._make_test(256, 64, 8, seed=3))

        # Non-power-of-2: T=30, cycling uniformly
        uniform30 = (torch.arange(30, device="cuda") % 4).to(torch.int32)
        tests.append(self._make_test(30, 16, 4, expert_idx_tensor=uniform30, seed=4))

        # Non-power-of-2: T=100, random assignments to 6 experts (includes negatives in x)
        tests.append(self._make_test(100, 32, 6, seed=5))

        # Realistic: T=1024 tokens, D=128, E=8 (includes negatives in x)
        tests.append(self._make_test(1024, 128, 8, seed=6))

        # Zero x values, random routing
        torch.manual_seed(7)
        zero_x = torch.zeros(64, 32, device="cuda", dtype=torch.float32)
        tests.append(
            {
                "x": zero_x,
                "expert_idx": torch.randint(0, 4, (64,), device="cuda", dtype=torch.int32),
                "dispatched_x": torch.zeros(4, 64, 32, device="cuda", dtype=torch.float32),
                "token_counts": torch.zeros(4, device="cuda", dtype=torch.int32),
                "T": 64,
                "D": 32,
                "E": 4,
                "capacity": 64,
            }
        )

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        T, D, E = 16384, 512, 8
        capacity = T
        torch.manual_seed(0)
        x = torch.randn(T, D, device="cuda", dtype=torch.float32)
        expert_idx = torch.randint(0, E, (T,), device="cuda", dtype=torch.int32)
        dispatched_x = torch.zeros(E, capacity, D, device="cuda", dtype=torch.float32)
        token_counts = torch.zeros(E, device="cuda", dtype=torch.int32)
        return {
            "x": x,
            "expert_idx": expert_idx,
            "dispatched_x": dispatched_x,
            "token_counts": token_counts,
            "T": T,
            "D": D,
            "E": E,
            "capacity": capacity,
        }
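The per-expert loop in `reference_impl` above can also be expressed with a single stable argsort, which is closer to how a fast kernel might compute destination slots. A NumPy sketch of the same semantics (`dispatch_np` is a hypothetical illustration, not the required API):

```python
import numpy as np

def dispatch_np(x, expert_idx, E, capacity):
    T, D = x.shape
    dispatched = np.zeros((E, capacity, D), dtype=x.dtype)
    counts = np.bincount(expert_idx, minlength=E).astype(np.int32)
    # A stable sort keeps tokens of the same expert in original order.
    order = np.argsort(expert_idx, kind="stable")
    # Slot of each sorted token = its offset within its expert's group.
    starts = np.concatenate(([0], np.cumsum(counts)[:-1]))
    slot = np.arange(T) - np.repeat(starts, counts)
    dispatched[expert_idx[order], slot, :] = x[order]
    return dispatched, counts
```

On the example routing `[0, 1, 0, 1]` this produces the same buffers as the reference loop.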
5 changes: 5 additions & 0 deletions challenges/medium/91_moe_token_dispatch/starter/starter.cu
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// x, expert_idx, dispatched_x, token_counts are device pointers
extern "C" void solve(const float* x, const int* expert_idx, float* dispatched_x, int* token_counts,
int T, int D, int E, int capacity) {}
17 changes: 17 additions & 0 deletions challenges/medium/91_moe_token_dispatch/starter/starter.cute.py
@@ -0,0 +1,17 @@
import cutlass
import cutlass.cute as cute


# x, expert_idx, dispatched_x, token_counts are tensors on the GPU
@cute.jit
def solve(
    x: cute.Tensor,
    expert_idx: cute.Tensor,
    dispatched_x: cute.Tensor,
    token_counts: cute.Tensor,
    T: cute.Int32,
    D: cute.Int32,
    E: cute.Int32,
    capacity: cute.Int32,
):
    pass
15 changes: 15 additions & 0 deletions challenges/medium/91_moe_token_dispatch/starter/starter.jax.py
@@ -0,0 +1,15 @@
import jax
import jax.numpy as jnp


# x, expert_idx are tensors on GPU
@jax.jit
def solve(
    x: jax.Array,
    expert_idx: jax.Array,
    T: int,
    D: int,
    E: int,
    capacity: int,
) -> tuple[jax.Array, jax.Array]:
    pass
19 changes: 19 additions & 0 deletions challenges/medium/91_moe_token_dispatch/starter/starter.mojo
@@ -0,0 +1,19 @@
from std.gpu.host import DeviceContext
from std.gpu import block_dim, block_idx, thread_idx
from std.memory import UnsafePointer
from std.math import ceildiv


# x, expert_idx, dispatched_x, token_counts are device pointers
@export
def solve(
    x: UnsafePointer[Float32, MutExternalOrigin],
    expert_idx: UnsafePointer[Int32, MutExternalOrigin],
    dispatched_x: UnsafePointer[Float32, MutExternalOrigin],
    token_counts: UnsafePointer[Int32, MutExternalOrigin],
    T: Int32,
    D: Int32,
    E: Int32,
    capacity: Int32,
) raises:
    pass
@@ -0,0 +1,15 @@
import torch


# x, expert_idx, dispatched_x, token_counts are tensors on the GPU
def solve(
    x: torch.Tensor,
    expert_idx: torch.Tensor,
    dispatched_x: torch.Tensor,
    token_counts: torch.Tensor,
    T: int,
    D: int,
    E: int,
    capacity: int,
):
    pass