challenges/hard/93_llama_transformer_block/challenge.html
<p>
Implement a single Llama-style transformer decoder block. Given an input tensor \(x\) of shape
<code>(seq_len, 512)</code>, a packed weight buffer, and precomputed RoPE tables, compute the
block's output using a pre-norm architecture with Grouped Query Attention (GQA), Rotary Position
Embeddings (RoPE), and a SwiGLU feed-forward network.
</p>

<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 340 660" width="340" height="660" style="display:block; margin:20px auto;">
<defs>
<marker id="ah" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
<path d="M0 0L10 5L0 10z" fill="#999"/>
</marker>
</defs>
<rect width="340" height="660" fill="#222"/>

<!-- Input label -->
<text x="140" y="20" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x (seq_len, 512)</text>

<!-- Arrow: input -> RMSNorm1 -->
<line x1="140" y1="28" x2="140" y2="46" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 1: fork right, down, back left to Add1 -->
<line x1="140" y1="36" x2="280" y2="36" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="36" x2="280" y2="306" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="306" x2="157" y2="306" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="290" y="175" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,290,175)">residual</text>

<!-- RMSNorm1 -->
<rect x="60" y="49" width="160" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="140" y="69" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">RMSNorm 1</text>
<line x1="140" y1="79" x2="140" y2="97" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- QKV Proj -->
<rect x="60" y="100" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="140" y="120" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">QKV Projection (GQA)</text>
<line x1="140" y1="130" x2="140" y2="148" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- RoPE -->
<rect x="60" y="151" width="160" height="30" rx="5" fill="#2d1e4d" stroke="#7755bb" stroke-width="1"/>
<text x="140" y="171" text-anchor="middle" fill="#ccaaee" font-size="12" font-family="sans-serif">RoPE (Q and K)</text>
<line x1="140" y1="181" x2="140" y2="199" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Causal Attn -->
<rect x="60" y="202" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="140" y="222" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Causal Attention</text>
<line x1="140" y1="232" x2="140" y2="250" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Output Proj -->
<rect x="60" y="253" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="140" y="273" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Output Projection</text>
<line x1="140" y1="283" x2="140" y2="294" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 1 -->
<circle cx="140" cy="306" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="140" y="311" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="140" y1="318" x2="140" y2="342" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 2 -->
<line x1="140" y1="330" x2="280" y2="330" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="330" x2="280" y2="586" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="280" y1="586" x2="157" y2="586" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="290" y="460" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,290,460)">residual</text>

<!-- RMSNorm2 -->
<rect x="60" y="345" width="160" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="140" y="365" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">RMSNorm 2</text>
<line x1="140" y1="375" x2="140" y2="393" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Gate + Up Proj -->
<rect x="60" y="396" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="140" y="416" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Gate &amp; Up Proj (512→1408)</text>
<line x1="140" y1="426" x2="140" y2="444" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- SiLU + multiply -->
<rect x="60" y="447" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="140" y="467" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">SiLU(gate) &#x2299; up</text>
<line x1="140" y1="477" x2="140" y2="495" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Down Proj -->
<rect x="60" y="498" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="140" y="518" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Down Proj (1408→512)</text>
<line x1="140" y1="528" x2="140" y2="574" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 2 -->
<circle cx="140" cy="586" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="140" y="591" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="140" y1="598" x2="140" y2="622" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Output label -->
<text x="140" y="640" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">output (seq_len, 512)</text>
</svg>

<p>
The block follows Llama's <strong>pre-norm</strong> architecture. Unlike GPT-2, it uses
<strong>RMSNorm</strong> (no mean subtraction, no additive bias), <strong>Grouped Query
Attention</strong> with 8 query heads and 2 key/value heads, <strong>Rotary Position
Embeddings</strong> applied to Q and K, and a <strong>SwiGLU</strong> feed-forward network.
None of the linear projections have bias terms.
</p>

\[
\begin{aligned}
x' &= x + \text{Attn}\!\left(\text{RMSNorm}_1(x),\; \cos,\; \sin\right) \\[4pt]
\text{output} &= x' + \text{FFN}\!\left(\text{RMSNorm}_2(x')\right)
\end{aligned}
\]
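The two residual equations above fix the skeleton of the block. As a minimal sketch (the sub-layer callables here are placeholders, not the required implementations), the pre-norm composition is:

```python
import numpy as np

def transformer_block(x, rmsnorm1, rmsnorm2, attn, ffn):
    # Pre-norm: each sub-layer sees a normalized input, and the *unnormalized*
    # input is added back afterward (the dashed "residual" paths in the diagram).
    h = x + attn(rmsnorm1(x))      # attention sub-layer + residual 1
    return h + ffn(rmsnorm2(h))    # feed-forward sub-layer + residual 2
```

Note that normalization happens before each sub-layer, never on the residual path itself.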

<p>The sub-operations in detail:</p>

\[
\begin{aligned}
\text{RMSNorm}(z, w) &= \frac{z}{\sqrt{\frac{1}{d}\sum_i z_i^2 + \varepsilon}} \odot w, \quad \varepsilon = 10^{-5} \\[8pt]
Q &= \text{RMSNorm}_1(x)\, W_Q^\top \in \mathbb{R}^{T \times 512}, \quad \text{reshape to } (T, 8, 64) \\[4pt]
K &= \text{RMSNorm}_1(x)\, W_K^\top \in \mathbb{R}^{T \times 128}, \quad \text{reshape to } (T, 2, 64) \\[4pt]
V &= \text{RMSNorm}_1(x)\, W_V^\top \in \mathbb{R}^{T \times 128}, \quad \text{reshape to } (T, 2, 64) \\[8pt]
\text{RoPE}(q, \cos, \sin) &: \quad [q_1 \mid q_2] \mapsto [q_1 \odot \cos - q_2 \odot \sin \mid q_1 \odot \sin + q_2 \odot \cos] \\[4pt]
&\quad q_1 = q[\ldots, {:}32],\; q_2 = q[\ldots, {32:}] \\[8pt]
\text{GQA} &: \text{repeat } K,V \text{ along head dim } 4\times \text{ to match 8 Q heads} \\[4pt]
\text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{64}} + M_{\text{causal}}\right) V_i \\[8pt]
\text{Attn}(x) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_8)\; W_O^\top \\[8pt]
\text{FFN}(z) &= \bigl(\text{SiLU}(z\, W_{\text{gate}}^\top) \odot z\, W_{\text{up}}^\top\bigr)\; W_{\text{down}}^\top
\end{aligned}
\]

<p>where \(M_{\text{causal}}\) is the causal mask (\(-\infty\) strictly above the diagonal,
\(0\) on and below it) and \(\text{SiLU}(x) = x \cdot \sigma(x)\), with \(\sigma\) the logistic
sigmoid.</p>
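The RMSNorm and RoPE definitions above can be sketched in NumPy for reference. This is illustrative only (the challenge forbids external libraries); the half-split RoPE convention shown matches the \([q_1 \mid q_2]\) formula, with <code>cos</code>/<code>sin</code> broadcast across the head axis:

```python
import numpy as np

def rmsnorm(z, w, eps=1e-5):
    # Divide by the root-mean-square over the feature dim, then scale; no bias.
    rms = np.sqrt(np.mean(z * z, axis=-1, keepdims=True) + eps)
    return (z / rms) * w

def rope(q, cos, sin):
    # q: (T, n_heads, 64); cos, sin: (T, 32).
    # Split into halves q1 = q[..., :32], q2 = q[..., 32:], rotate pairwise.
    q1, q2 = q[..., :32], q[..., 32:]
    c = cos[:, None, :]  # (T, 1, 32), broadcasts over heads
    s = sin[:, None, :]
    return np.concatenate([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1)
```

With <code>cos = 1</code> and <code>sin = 0</code> the rotation is the identity, which is a useful sanity check.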

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in the <code>output</code> tensor</li>
<li>RMSNorm uses \(\varepsilon = 10^{-5}\), no additive bias</li>
<li>Apply causal masking: position \(i\) attends only to positions \(\le i\)</li>
<li>Repeat K and V heads \(4\times\) (GQA groups) before computing attention</li>
<li><code>cos</code> and <code>sin</code> have shape <code>(seq_len, 32)</code>; apply
  them to every Q head and every K head independently</li>
</ul>
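The GQA repetition and causal masking requirements above can be sketched as follows. This is a NumPy reference for the math, not a submission (external libraries are not permitted); head <code>h</code> of Q reads KV head <code>h // 4</code>:

```python
import numpy as np

def gqa_attention(q, k, v):
    # q: (T, 8, 64); k, v: (T, 2, 64).
    T = q.shape[0]
    k = np.repeat(k, 4, axis=1)   # (T, 8, 64): KV head order 0,0,0,0,1,1,1,1
    v = np.repeat(v, 4, axis=1)
    # -inf strictly above the diagonal so position i attends only to j <= i.
    mask = np.triu(np.full((T, T), -np.inf, dtype=q.dtype), k=1)
    out = np.empty_like(q)
    for h in range(8):
        scores = q[:, h] @ k[:, h].T / np.sqrt(64.0) + mask   # (T, T)
        scores -= scores.max(axis=-1, keepdims=True)          # softmax stability
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        out[:, h] = w @ v[:, h]
    return out
```

Because position 0 can attend only to itself, its output row equals the corresponding (repeated) V row, which makes a quick correctness check.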

<h2>Weight Layout</h2>
<p>All parameters are packed into a single contiguous <code>weights</code> buffer
(2,819,072 floats) in the order below. All 2-D matrices are stored row-major
with shape <code>(out_dim, in_dim)</code>. There are no bias terms.</p>

<table style="border-collapse:separate; border-spacing:16px 6px;">
<tr>
<th style="text-align:left;">Parameter</th>
<th style="text-align:left;">Shape</th>
<th style="text-align:right;">Size</th>
<th style="text-align:right;">Offset</th>
</tr>
<tr>
<td>\(w_1\) (RMSNorm 1 scale)</td>
<td>(512,)</td>
<td style="text-align:right;">512</td>
<td style="text-align:right;">0</td>
</tr>
<tr>
<td>\(W_Q\)</td>
<td>(512, 512)</td>
<td style="text-align:right;">262,144</td>
<td style="text-align:right;">512</td>
</tr>
<tr>
<td>\(W_K\)</td>
<td>(128, 512)</td>
<td style="text-align:right;">65,536</td>
<td style="text-align:right;">262,656</td>
</tr>
<tr>
<td>\(W_V\)</td>
<td>(128, 512)</td>
<td style="text-align:right;">65,536</td>
<td style="text-align:right;">328,192</td>
</tr>
<tr>
<td>\(W_O\)</td>
<td>(512, 512)</td>
<td style="text-align:right;">262,144</td>
<td style="text-align:right;">393,728</td>
</tr>
<tr>
<td>\(w_2\) (RMSNorm 2 scale)</td>
<td>(512,)</td>
<td style="text-align:right;">512</td>
<td style="text-align:right;">655,872</td>
</tr>
<tr>
<td>\(W_{\text{gate}}\)</td>
<td>(1408, 512)</td>
<td style="text-align:right;">720,896</td>
<td style="text-align:right;">656,384</td>
</tr>
<tr>
<td>\(W_{\text{up}}\)</td>
<td>(1408, 512)</td>
<td style="text-align:right;">720,896</td>
<td style="text-align:right;">1,377,280</td>
</tr>
<tr>
<td>\(W_{\text{down}}\)</td>
<td>(512, 1408)</td>
<td style="text-align:right;">720,896</td>
<td style="text-align:right;">2,098,176</td>
</tr>
</table>

<h2>Example</h2>
<p>With <code>seq_len</code> = 4, <code>x</code> drawn uniformly from [&minus;1, 1], and randomly
initialized weights:</p>
<pre>
Input: x.shape = (4, 512) # 4 token hidden states
weights.shape = (2819072,) # packed weight buffer
cos.shape = (4, 32) # precomputed RoPE cosines
sin.shape = (4, 32) # precomputed RoPE sines
seq_len = 4
Output: output.shape = (4, 512) # transformed token hidden states
</pre>
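For the feed-forward half specifically, the SwiGLU formula reduces to two up-projections, an elementwise gate, and a down-projection. A hedged NumPy sketch (reference math only, not a valid native-only submission):

```python
import numpy as np

def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def ffn(z, w_gate, w_up, w_down):
    # z: (T, 512); weights are (out_dim, in_dim) row-major, hence z @ W.T.
    gate = z @ w_gate.T                  # (T, 1408)
    up = z @ w_up.T                      # (T, 1408)
    return (silu(gate) * up) @ w_down.T  # (T, 512), no bias anywhere
```

Because SiLU(0) = 0, a zero input maps to a zero output, which is an easy sanity check before wiring the FFN into the residual path.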

<h2>Constraints</h2>
<ul>
<li><code>d_model</code> = 512, <code>n_q_heads</code> = 8, <code>n_kv_heads</code> = 2,
<code>head_dim</code> = 64, <code>ffn_hidden</code> = 1,408</li>
<li>1 &le; <code>seq_len</code> &le; 4,096</li>
<li>All tensors use 32-bit floating point</li>
<li>Performance is measured with <code>seq_len</code> = 2,048</li>
</ul>