Commit c3cdc50

Add challenge 91: MoE Token Dispatch (Medium)
Teaches stable scatter / parallel dispatch — a key MoE inference building block that follows naturally from challenge 67 (MoE Top-K Gating). Solvers must pack T tokens into per-expert buffers [E, capacity, D], preserving original token order within each expert group.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ca891d3 commit c3cdc50

8 files changed

Lines changed: 404 additions & 0 deletions

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
<p>
In Mixture-of-Experts (MoE) neural networks, a router assigns each input token to one of
<code>E</code> expert networks. Before those experts can execute in parallel, tokens must be
<em>dispatched</em>: reorganized from their original sequence order into contiguous per-expert
buffers. Given <code>T</code> tokens of feature dimension <code>D</code> and their expert
assignments, fill each expert's slice of the output buffer with the tokens routed to that
expert, preserving the original token order within each expert's group.
</p>

<svg width="540" height="230" viewBox="0 0 540 230" style="display:block; margin:20px auto;" xmlns="http://www.w3.org/2000/svg">
  <rect width="540" height="230" rx="8" fill="#222"/>

  <!-- Title -->
  <text x="270" y="22" text-anchor="middle" fill="#ccc" font-family="monospace" font-size="12">MoE Token Dispatch (T=4, E=2)</text>

  <!-- Left: input tokens -->
  <text x="80" y="48" text-anchor="middle" fill="#aaa" font-family="monospace" font-size="11">x [T, D]</text>

  <!-- Token 0 → expert 0 (blue) -->
  <rect x="20" y="55" width="120" height="28" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
  <text x="80" y="73" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="11">token 0 → expert 0</text>

  <!-- Token 1 → expert 1 (green) -->
  <rect x="20" y="90" width="120" height="28" rx="4" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.5"/>
  <text x="80" y="108" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="11">token 1 → expert 1</text>

  <!-- Token 2 → expert 0 (blue) -->
  <rect x="20" y="125" width="120" height="28" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
  <text x="80" y="143" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="11">token 2 → expert 0</text>

  <!-- Token 3 → expert 1 (green) -->
  <rect x="20" y="160" width="120" height="28" rx="4" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.5"/>
  <text x="80" y="178" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="11">token 3 → expert 1</text>

  <!-- Arrows for expert 0 tokens -->
  <line x1="140" y1="69" x2="195" y2="90" stroke="#4a9eff" stroke-width="1.5" marker-end="url(#arrowblue)"/>
  <line x1="140" y1="139" x2="195" y2="110" stroke="#4a9eff" stroke-width="1.5" marker-end="url(#arrowblue)"/>

  <!-- Arrows for expert 1 tokens -->
  <line x1="140" y1="104" x2="195" y2="155" stroke="#4adf7f" stroke-width="1.5" marker-end="url(#arrowgreen)"/>
  <line x1="140" y1="174" x2="195" y2="170" stroke="#4adf7f" stroke-width="1.5" marker-end="url(#arrowgreen)"/>

  <!-- Arrow markers -->
  <defs>
    <marker id="arrowblue" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
      <path d="M0,0 L0,6 L8,3 z" fill="#4a9eff"/>
    </marker>
    <marker id="arrowgreen" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
      <path d="M0,0 L0,6 L8,3 z" fill="#4adf7f"/>
    </marker>
  </defs>

  <!-- Right: expert 0 buffer -->
  <text x="370" y="48" text-anchor="middle" fill="#aaa" font-family="monospace" font-size="11">dispatched_x [E, capacity, D]</text>

  <text x="280" y="78" fill="#4a9eff" font-family="monospace" font-size="10">expert 0</text>
  <rect x="200" y="83" width="160" height="24" rx="3" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
  <text x="280" y="99" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="10">slot 0: token 0 features</text>
  <rect x="200" y="108" width="160" height="24" rx="3" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/>
  <text x="280" y="124" text-anchor="middle" fill="#4a9eff" font-family="monospace" font-size="10">slot 1: token 2 features</text>

  <!-- Right: expert 1 buffer -->
  <text x="280" y="148" fill="#4adf7f" font-family="monospace" font-size="10">expert 1</text>
  <rect x="200" y="153" width="160" height="24" rx="3" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.2"/>
  <text x="280" y="169" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="10">slot 0: token 1 features</text>
  <rect x="200" y="178" width="160" height="24" rx="3" fill="#1a3a2a" stroke="#4adf7f" stroke-width="1.2"/>
  <text x="280" y="194" text-anchor="middle" fill="#4adf7f" font-family="monospace" font-size="10">slot 1: token 3 features</text>

  <!-- token_counts label -->
  <text x="280" y="218" text-anchor="middle" fill="#888" font-family="monospace" font-size="10">token_counts = [2, 2]</text>
</svg>

<h2>Implementation Requirements</h2>
<p>
Your <code>solve</code> function signature must not be changed. No external libraries beyond
those already imported. Write outputs to <code>dispatched_x</code> and
<code>token_counts</code> in place.
</p>
<p>
For each expert <code>e</code>, <code>dispatched_x[e, :token_counts[e], :]</code> must
contain exactly the tokens assigned to expert <code>e</code>, in the same relative order
they appear in the original token sequence (i.e., token index ascending). Entries at
positions <code>token_counts[e]</code> and beyond may be left as zero.
</p>

<h2>Example</h2>
<p>
Given <code>T</code> = 4 tokens of dimension <code>D</code> = 3 routed to
<code>E</code> = 2 experts (capacity = 4):
</p>

<p>Input tokens \(x\):</p>
<p>$$x = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix}$$</p>

<p>Expert assignments:</p>
<pre>expert_idx = [0, 1, 0, 1]</pre>

<p>Output token counts:</p>
<pre>token_counts = [2, 2]</pre>

<p>\(\texttt{dispatched\_x}[0]\) &mdash; expert 0 receives tokens 0 and 2 (in original order):</p>
<p>$$\begin{bmatrix} 1 & 0 & 0 \\ 0 & 0 & 1 \end{bmatrix}$$</p>

<p>\(\texttt{dispatched\_x}[1]\) &mdash; expert 1 receives tokens 1 and 3 (in original order):</p>
<p>$$\begin{bmatrix} 0 & 1 & 0 \\ 1 & 1 & 0 \end{bmatrix}$$</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>T</code> &le; 65,536</li>
<li>1 &le; <code>D</code> &le; 2,048</li>
<li>2 &le; <code>E</code> &le; 64</li>
<li><code>capacity</code> = <code>T</code> (always sufficient for any routing assignment)</li>
<li>All values in <code>expert_idx</code> are in the range [0, <code>E</code> &minus; 1]</li>
<li><code>x</code> values are <code>float32</code>; <code>expert_idx</code> and <code>token_counts</code> are <code>int32</code></li>
<li>Performance is measured with <code>T</code> = 16,384, <code>D</code> = 512, <code>E</code> = 8</li>
</ul>
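The worked example above can be checked with a few lines of plain Python. This is a CPU sketch of the dispatch semantics only, not a GPU solution; the `dispatch` helper name is illustrative and not part of the challenge API:

```python
def dispatch(x, expert_idx, E, capacity):
    """Pack tokens into per-expert buffers, preserving original token order."""
    D = len(x[0])
    dispatched = [[[0.0] * D for _ in range(capacity)] for _ in range(E)]
    counts = [0] * E
    for token, e in zip(x, expert_idx):  # iterating in original order keeps the scatter stable
        dispatched[e][counts[e]] = list(token)
        counts[e] += 1
    return dispatched, counts

x = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
dispatched, counts = dispatch(x, [0, 1, 0, 1], E=2, capacity=4)
# counts == [2, 2]; expert 0 holds tokens 0 and 2, expert 1 holds tokens 1 and 3
```

Unused slots past `counts[e]` stay zero, matching the "may be left as zero" requirement.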
Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
import ctypes
from typing import Any, Dict, List, Optional

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="MoE Token Dispatch",
            atol=1e-05,
            rtol=1e-05,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        expert_idx: torch.Tensor,
        dispatched_x: torch.Tensor,
        token_counts: torch.Tensor,
        T: int,
        D: int,
        E: int,
        capacity: int,
    ):
        assert x.shape == (T, D)
        assert expert_idx.shape == (T,)
        assert dispatched_x.shape == (E, capacity, D)
        assert token_counts.shape == (E,)
        assert x.dtype == torch.float32
        assert expert_idx.dtype == torch.int32
        assert dispatched_x.dtype == torch.float32
        assert token_counts.dtype == torch.int32
        assert x.device.type == "cuda"
        assert expert_idx.device.type == "cuda"

        for e in range(E):
            # torch.where returns indices in ascending order — stable within each expert
            indices = torch.where(expert_idx == e)[0]
            count = int(indices.shape[0])
            assert count <= capacity, f"Expert {e} has {count} tokens but capacity is {capacity}"
            dispatched_x[e, :count] = x[indices]
            token_counts[e] = count

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "x": (ctypes.POINTER(ctypes.c_float), "in"),
            "expert_idx": (ctypes.POINTER(ctypes.c_int), "in"),
            "dispatched_x": (ctypes.POINTER(ctypes.c_float), "out"),
            "token_counts": (ctypes.POINTER(ctypes.c_int), "out"),
            "T": (ctypes.c_int, "in"),
            "D": (ctypes.c_int, "in"),
            "E": (ctypes.c_int, "in"),
            "capacity": (ctypes.c_int, "in"),
        }

    def _make_test(
        self,
        T: int,
        D: int,
        E: int,
        expert_idx_tensor: Optional[torch.Tensor] = None,
        seed: int = 42,
    ) -> Dict[str, Any]:
        torch.manual_seed(seed)
        capacity = T
        x = torch.randn(T, D, device="cuda", dtype=torch.float32)
        if expert_idx_tensor is not None:
            expert_idx = expert_idx_tensor
        else:
            expert_idx = torch.randint(0, E, (T,), device="cuda", dtype=torch.int32)
        dispatched_x = torch.zeros(E, capacity, D, device="cuda", dtype=torch.float32)
        token_counts = torch.zeros(E, device="cuda", dtype=torch.int32)
        return {
            "x": x,
            "expert_idx": expert_idx,
            "dispatched_x": dispatched_x,
            "token_counts": token_counts,
            "T": T,
            "D": D,
            "E": E,
            "capacity": capacity,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        T, D, E = 4, 3, 2
        capacity = T
        x = torch.tensor(
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]],
            device="cuda",
            dtype=torch.float32,
        )
        expert_idx = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int32)
        dispatched_x = torch.zeros(E, capacity, D, device="cuda", dtype=torch.float32)
        token_counts = torch.zeros(E, device="cuda", dtype=torch.int32)
        return {
            "x": x,
            "expert_idx": expert_idx,
            "dispatched_x": dispatched_x,
            "token_counts": token_counts,
            "T": T,
            "D": D,
            "E": E,
            "capacity": capacity,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        tests = []

        # Edge case: single token goes to expert 0, expert 1 is empty
        tests.append(
            self._make_test(
                1,
                4,
                2,
                expert_idx_tensor=torch.tensor([0], device="cuda", dtype=torch.int32),
            )
        )

        # Edge case: two tokens, both assigned to expert 0
        tests.append(
            self._make_test(
                2,
                4,
                2,
                expert_idx_tensor=torch.tensor([0, 0], device="cuda", dtype=torch.int32),
            )
        )

        # Edge case: exactly one token per expert (T == E)
        tests.append(
            self._make_test(
                4,
                8,
                4,
                expert_idx_tensor=torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.int32),
                seed=1,
            )
        )

        # Skewed distribution: 6 of 8 tokens go to expert 0
        skewed = torch.tensor([0, 0, 0, 0, 0, 0, 1, 2], device="cuda", dtype=torch.int32)
        tests.append(self._make_test(8, 8, 4, expert_idx_tensor=skewed, seed=10))

        # Power-of-2: T=32, cycling uniformly through 4 experts
        uniform32 = (torch.arange(32, device="cuda") % 4).to(torch.int32)
        tests.append(self._make_test(32, 16, 4, expert_idx_tensor=uniform32, seed=2))

        # Power-of-2: T=256, random assignments to 8 experts
        tests.append(self._make_test(256, 64, 8, seed=3))

        # Non-power-of-2: T=30, cycling uniformly
        uniform30 = (torch.arange(30, device="cuda") % 4).to(torch.int32)
        tests.append(self._make_test(30, 16, 4, expert_idx_tensor=uniform30, seed=4))

        # Non-power-of-2: T=100, random assignments to 6 experts (includes negatives in x)
        tests.append(self._make_test(100, 32, 6, seed=5))

        # Realistic: T=1024 tokens, D=128, E=8 (includes negatives in x)
        tests.append(self._make_test(1024, 128, 8, seed=6))

        # Zero x values, random routing
        torch.manual_seed(7)
        zero_x = torch.zeros(64, 32, device="cuda", dtype=torch.float32)
        tests.append(
            {
                "x": zero_x,
                "expert_idx": torch.randint(0, 4, (64,), device="cuda", dtype=torch.int32),
                "dispatched_x": torch.zeros(4, 64, 32, device="cuda", dtype=torch.float32),
                "token_counts": torch.zeros(4, device="cuda", dtype=torch.int32),
                "T": 64,
                "D": 32,
                "E": 4,
                "capacity": 64,
            }
        )

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        T, D, E = 16384, 512, 8
        capacity = T
        torch.manual_seed(0)
        x = torch.randn(T, D, device="cuda", dtype=torch.float32)
        expert_idx = torch.randint(0, E, (T,), device="cuda", dtype=torch.int32)
        dispatched_x = torch.zeros(E, capacity, D, device="cuda", dtype=torch.float32)
        token_counts = torch.zeros(E, device="cuda", dtype=torch.int32)
        return {
            "x": x,
            "expert_idx": expert_idx,
            "dispatched_x": dispatched_x,
            "token_counts": token_counts,
            "T": T,
            "D": D,
            "E": E,
            "capacity": capacity,
        }
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// x, expert_idx, dispatched_x, token_counts are device pointers
extern "C" void solve(const float* x, const int* expert_idx, float* dispatched_x, int* token_counts,
                      int T, int D, int E, int capacity) {}
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import cutlass
import cutlass.cute as cute


# x, expert_idx, dispatched_x, token_counts are tensors on the GPU
@cute.jit
def solve(
    x: cute.Tensor,
    expert_idx: cute.Tensor,
    dispatched_x: cute.Tensor,
    token_counts: cute.Tensor,
    T: cute.Int32,
    D: cute.Int32,
    E: cute.Int32,
    capacity: cute.Int32,
):
    pass
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import jax
import jax.numpy as jnp


# x, expert_idx are tensors on GPU
@jax.jit
def solve(
    x: jax.Array,
    expert_idx: jax.Array,
    T: int,
    D: int,
    E: int,
    capacity: int,
) -> tuple[jax.Array, jax.Array]:
    pass
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
from std.gpu.host import DeviceContext
from std.gpu import block_dim, block_idx, thread_idx
from std.memory import UnsafePointer
from std.math import ceildiv


# x, expert_idx, dispatched_x, token_counts are device pointers
@export
def solve(
    x: UnsafePointer[Float32, MutExternalOrigin],
    expert_idx: UnsafePointer[Int32, MutExternalOrigin],
    dispatched_x: UnsafePointer[Float32, MutExternalOrigin],
    token_counts: UnsafePointer[Int32, MutExternalOrigin],
    T: Int32,
    D: Int32,
    E: Int32,
    capacity: Int32,
) raises:
    pass
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
import torch


# x, expert_idx, dispatched_x, token_counts are tensors on the GPU
def solve(
    x: torch.Tensor,
    expert_idx: torch.Tensor,
    dispatched_x: torch.Tensor,
    token_counts: torch.Tensor,
    T: int,
    D: int,
    E: int,
    capacity: int,
):
    pass
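A minimal sketch of one way the PyTorch stub above could be filled in, assuming a recent PyTorch where `torch.argsort` accepts `stable=True`. A stable argsort on the expert ids groups tokens by expert while preserving their original order, so the scatter collapses to a single indexed assignment. Shown device-agnostic; this is not presented as the intended or fastest solution:

```python
import torch

def solve(x, expert_idx, dispatched_x, token_counts, T, D, E, capacity):
    e = expert_idx.long()
    order = torch.argsort(e, stable=True)          # group by expert, keep token order
    sorted_e = e[order]
    counts = torch.bincount(sorted_e, minlength=E)
    starts = torch.cumsum(counts, dim=0) - counts  # first slot index of each expert group
    pos = torch.arange(T, device=x.device) - starts[sorted_e]  # slot within expert
    dispatched_x[sorted_e, pos] = x[order]         # stable scatter into [E, capacity, D]
    token_counts.copy_(counts.to(torch.int32))
```

On the worked example (`expert_idx = [0, 1, 0, 1]`), `order` is `[0, 2, 1, 3]` and `pos` is `[0, 1, 0, 1]`, reproducing the buffers in the figure.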
