
Commit 721e141

Add challenge 93: Llama Transformer Block (Hard)

Adds a complete Llama-style transformer decoder block challenge that requires implementing RMSNorm, Grouped Query Attention with RoPE, causal masking, and a SwiGLU FFN — mirroring modern LLM inference kernels (Llama 2/3, Mistral, Gemma).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

1 parent ca891d3 commit 721e141

8 files changed: 526 additions, 0 deletions
<p>
Implement a single Llama-style transformer decoder block. Given an input tensor \(x\) of shape
<code>(seq_len, 512)</code>, a packed weight buffer, and precomputed RoPE tables, compute the
output using pre-norm architecture with Grouped Query Attention (GQA), Rotary Position Embeddings
(RoPE), and a SwiGLU feed-forward network.
</p>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 340 660" width="340" height="660" style="display:block; margin:20px auto;">
<defs>
<marker id="ah" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
<path d="M0 0L10 5L0 10z" fill="#999"/>
</marker>
</defs>
<rect width="340" height="660" fill="#222"/>

<!-- Input label -->
<text x="140" y="20" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x (seq_len, 512)</text>

<!-- Arrow: input -> RMSNorm1 -->
<line x1="140" y1="28" x2="140" y2="46" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 1: fork right, down, back left to Add1 -->
<line x1="140" y1="36" x2="280" y2="36" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="36" x2="280" y2="306" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="306" x2="157" y2="306" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="290" y="175" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,290,175)">residual</text>

<!-- RMSNorm1 -->
<rect x="60" y="49" width="160" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="140" y="69" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">RMSNorm 1</text>
<line x1="140" y1="79" x2="140" y2="97" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- QKV Proj -->
<rect x="60" y="100" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="140" y="120" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">QKV Projection (GQA)</text>
<line x1="140" y1="130" x2="140" y2="148" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- RoPE -->
<rect x="60" y="151" width="160" height="30" rx="5" fill="#2d1e4d" stroke="#7755bb" stroke-width="1"/>
<text x="140" y="171" text-anchor="middle" fill="#ccaaee" font-size="12" font-family="sans-serif">RoPE (Q and K)</text>
<line x1="140" y1="181" x2="140" y2="199" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Causal Attn -->
<rect x="60" y="202" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="140" y="222" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Causal Attention</text>
<line x1="140" y1="232" x2="140" y2="250" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Output Proj -->
<rect x="60" y="253" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="140" y="273" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Output Projection</text>
<line x1="140" y1="283" x2="140" y2="294" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 1 -->
<circle cx="140" cy="306" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="140" y="311" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="140" y1="318" x2="140" y2="342" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 2 -->
<line x1="140" y1="330" x2="280" y2="330" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="330" x2="280" y2="586" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="586" x2="157" y2="586" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="290" y="460" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,290,460)">residual</text>

<!-- RMSNorm2 -->
<rect x="60" y="345" width="160" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="140" y="365" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">RMSNorm 2</text>
<line x1="140" y1="375" x2="140" y2="393" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Gate + Up Proj -->
<rect x="60" y="396" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="140" y="416" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Gate &amp; Up Proj (512→1408)</text>
<line x1="140" y1="426" x2="140" y2="444" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- SiLU + multiply -->
<rect x="60" y="447" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="140" y="467" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">SiLU(gate) &#x2299; up</text>
<line x1="140" y1="477" x2="140" y2="495" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Down Proj -->
<rect x="60" y="498" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="140" y="518" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Down Proj (1408→512)</text>
<line x1="140" y1="528" x2="140" y2="574" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 2 -->
<circle cx="140" cy="586" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="140" y="591" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="140" y1="598" x2="140" y2="622" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Output label -->
<text x="140" y="640" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">output (seq_len, 512)</text>
</svg>
<p>
The block follows Llama's <strong>pre-norm</strong> architecture. Unlike GPT-2, it uses
<strong>RMSNorm</strong> (no mean subtraction, no additive bias), <strong>Grouped Query
Attention</strong> with 8 query heads and 2 key/value heads, <strong>Rotary Position
Embeddings</strong> applied to Q and K, and a <strong>SwiGLU</strong> feed-forward network.
None of the linear projections have bias terms.
</p>
\[
\begin{aligned}
x' &= x + \text{Attn}\!\left(\text{RMSNorm}_1(x),\; \cos,\; \sin\right) \\[4pt]
\text{output} &= x' + \text{FFN}\!\left(\text{RMSNorm}_2(x')\right)
\end{aligned}
\]

<p>The sub-operations in detail:</p>

\[
\begin{aligned}
\text{RMSNorm}(z, w) &= \frac{z}{\sqrt{\frac{1}{d}\sum_i z_i^2 + \varepsilon}} \odot w, \quad \varepsilon = 10^{-5} \\[8pt]
Q &= \text{RMSNorm}_1(x)\, W_Q^\top \in \mathbb{R}^{T \times 512}, \quad \text{reshape to } (T, 8, 64) \\[4pt]
K &= \text{RMSNorm}_1(x)\, W_K^\top \in \mathbb{R}^{T \times 128}, \quad \text{reshape to } (T, 2, 64) \\[4pt]
V &= \text{RMSNorm}_1(x)\, W_V^\top \in \mathbb{R}^{T \times 128}, \quad \text{reshape to } (T, 2, 64) \\[8pt]
\text{RoPE}(q, \cos, \sin) &: \quad [q_1 \mid q_2] \mapsto [q_1 \odot \cos - q_2 \odot \sin \mid q_1 \odot \sin + q_2 \odot \cos] \\[4pt]
&\quad q_1 = q[\ldots, {:}32],\; q_2 = q[\ldots, {32:}] \\[8pt]
\text{GQA} &: \text{repeat } K,V \text{ along the head axis } 4\times \text{ to match the 8 Q heads} \\[4pt]
\text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{64}} + M_{\text{causal}}\right) V_i \\[8pt]
\text{Attn}(x) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_8)\; W_O^\top \\[8pt]
\text{FFN}(z) &= \bigl(\text{SiLU}(z\, W_{\text{gate}}^\top) \odot z\, W_{\text{up}}^\top\bigr)\; W_{\text{down}}^\top
\end{aligned}
\]

<p>where \(M_{\text{causal}}\) is the upper-triangular causal mask (\(-\infty\) above the diagonal)
and \(\text{SiLU}(x) = x \cdot \sigma(x)\).</p>
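<p>For concreteness, the two less common primitives can be transcribed almost verbatim from the
formulas above. The following is a minimal NumPy sketch for illustration only (an actual
submission must use native features, per the requirements below); the function names are our own:</p>
<pre>
import numpy as np

def rms_norm(z, w, eps=1e-5):
    # z / sqrt(mean(z^2) + eps) * w  --  no mean subtraction, no bias.
    rms = np.sqrt(np.mean(z * z, axis=-1, keepdims=True) + eps)
    return (z / rms) * w

def apply_rope(q, cos, sin):
    # q: (T, n_heads, 64); cos, sin: (T, 32).
    # [q1 | q2] -> [q1*cos - q2*sin | q1*sin + q2*cos], applied per head.
    q1, q2 = q[..., :32], q[..., 32:]
    c, s = cos[:, None, :], sin[:, None, :]  # broadcast over the head axis
    return np.concatenate([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1)
</pre>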
<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in the <code>output</code> tensor</li>
<li>RMSNorm uses \(\varepsilon = 10^{-5}\), no additive bias</li>
<li>Apply causal masking: position \(i\) attends only to positions \(\le i\)</li>
<li>Repeat K and V heads \(4\times\) (GQA groups) before computing attention
(see the sketch after this list)</li>
<li><code>cos</code> and <code>sin</code> have shape <code>(seq_len, 32)</code>; apply
them to both Q and K heads independently</li>
</ul>
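<p>The masking and head-repetition requirements combine into the attention core. A hedged NumPy
sketch (illustration only, not a valid submission):</p>
<pre>
import numpy as np

def gqa_causal_attention(q, k, v):
    # q: (T, 8, 64); k, v: (T, 2, 64), all after RoPE. Returns (T, 512).
    T = q.shape[0]
    k = np.repeat(k, 4, axis=1)  # expand 2 KV heads 4x -> (T, 8, 64)
    v = np.repeat(v, 4, axis=1)
    scores = np.einsum('qhd,khd->hqk', q, k) / np.sqrt(64.0)
    scores += np.triu(np.full((T, T), -np.inf), k=1)  # causal mask
    scores -= scores.max(axis=-1, keepdims=True)      # stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    heads = np.einsum('hqk,khd->qhd', probs, v)       # (T, 8, 64)
    return heads.reshape(T, 8 * 64)                   # concatenate heads
</pre>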
<h2>Weight Layout</h2>
<p>All parameters are packed into a single contiguous <code>weights</code> buffer
(2,819,072 floats) in the order below. All 2-D matrices are stored row-major
with shape <code>(out_dim, in_dim)</code>. There are no bias terms.</p>

<table style="border-collapse:separate; border-spacing:16px 6px;">
<tr>
<th style="text-align:left;">Parameter</th>
<th style="text-align:left;">Shape</th>
<th style="text-align:right;">Size</th>
<th style="text-align:right;">Offset</th>
</tr>
<tr><td>\(w_1\) (RMSNorm 1 scale)</td><td>(512,)</td><td style="text-align:right;">512</td><td style="text-align:right;">0</td></tr>
<tr><td>\(W_Q\)</td><td>(512, 512)</td><td style="text-align:right;">262,144</td><td style="text-align:right;">512</td></tr>
<tr><td>\(W_K\)</td><td>(128, 512)</td><td style="text-align:right;">65,536</td><td style="text-align:right;">262,656</td></tr>
<tr><td>\(W_V\)</td><td>(128, 512)</td><td style="text-align:right;">65,536</td><td style="text-align:right;">328,192</td></tr>
<tr><td>\(W_O\)</td><td>(512, 512)</td><td style="text-align:right;">262,144</td><td style="text-align:right;">393,728</td></tr>
<tr><td>\(w_2\) (RMSNorm 2 scale)</td><td>(512,)</td><td style="text-align:right;">512</td><td style="text-align:right;">655,872</td></tr>
<tr><td>\(W_{\text{gate}}\)</td><td>(1408, 512)</td><td style="text-align:right;">720,896</td><td style="text-align:right;">656,384</td></tr>
<tr><td>\(W_{\text{up}}\)</td><td>(1408, 512)</td><td style="text-align:right;">720,896</td><td style="text-align:right;">1,377,280</td></tr>
<tr><td>\(W_{\text{down}}\)</td><td>(512, 1408)</td><td style="text-align:right;">720,896</td><td style="text-align:right;">2,098,176</td></tr>
</table>
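<p>One workable approach is to slice the buffer once up front using the offsets above. A minimal
NumPy sketch (the <code>LAYOUT</code> list is simply the table transcribed; the names are our own):</p>
<pre>
import numpy as np

LAYOUT = [  # (name, shape, offset) from the Weight Layout table
    ('w1',     (512,),      0),
    ('W_Q',    (512, 512),  512),
    ('W_K',    (128, 512),  262_656),
    ('W_V',    (128, 512),  328_192),
    ('W_O',    (512, 512),  393_728),
    ('w2',     (512,),      655_872),
    ('W_gate', (1408, 512), 656_384),
    ('W_up',   (1408, 512), 1_377_280),
    ('W_down', (512, 1408), 2_098_176),
]

def unpack(weights):
    # Row-major (out_dim, in_dim) views into the flat buffer; no copies.
    params = {}
    for name, shape, off in LAYOUT:
        n = int(np.prod(shape))
        params[name] = weights[off:off + n].reshape(shape)
    return params
</pre>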
<h2>Example</h2>
<p>With <code>seq_len</code> = 4, <code>x</code> drawn uniformly from [&minus;1, 1], and randomly
initialized weights:</p>
<pre>
Input:  x.shape       = (4, 512)      # 4 token hidden states
        weights.shape = (2,819,072,)  # packed weight buffer
        cos.shape     = (4, 32)       # precomputed RoPE cosines
        sin.shape     = (4, 32)       # precomputed RoPE sines
        seq_len       = 4
Output: output.shape  = (4, 512)      # transformed token hidden states
</pre>
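<p>Tying the sketches together, an end-to-end reference of the whole block. The parameter order
of <code>solve</code> is assumed here for illustration; use whatever signature the challenge
template actually provides. Helpers are the ones sketched earlier:</p>
<pre>
import numpy as np  # plus rms_norm, apply_rope, gqa_causal_attention, unpack from above

def solve(x, weights, cos, sin, seq_len, output):
    p = unpack(weights)
    T = seq_len

    # Attention sub-block with the first residual connection.
    h = rms_norm(x, p['w1'])
    q = (h @ p['W_Q'].T).reshape(T, 8, 64)
    k = (h @ p['W_K'].T).reshape(T, 2, 64)
    v = (h @ p['W_V'].T).reshape(T, 2, 64)
    q = apply_rope(q, cos, sin)
    k = apply_rope(k, cos, sin)
    x = x + gqa_causal_attention(q, k, v) @ p['W_O'].T

    # SwiGLU feed-forward sub-block with the second residual.
    h = rms_norm(x, p['w2'])
    gate = h @ p['W_gate'].T
    up = h @ p['W_up'].T
    silu = gate / (1.0 + np.exp(-gate))  # SiLU(g) = g * sigmoid(g)
    output[:] = x + (silu * up) @ p['W_down'].T
</pre>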
<h2>Constraints</h2>
<ul>
<li><code>d_model</code> = 512, <code>n_q_heads</code> = 8, <code>n_kv_heads</code> = 2,
<code>head_dim</code> = 64, <code>ffn_hidden</code> = 1,408</li>
<li>1 &le; <code>seq_len</code> &le; 4,096</li>
<li>All tensors use 32-bit floating point</li>
<li>Performance is measured with <code>seq_len</code> = 2,048</li>
</ul>
