<p>
  Implement a single Llama-style transformer decoder block. Given an input tensor \(x\) of shape
  <code>(seq_len, 512)</code>, a packed weight buffer, and precomputed RoPE tables, compute the
  output using a pre-norm architecture with Grouped Query Attention (GQA), Rotary Position
  Embeddings (RoPE), and a SwiGLU feed-forward network.
</p>

<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 340 660" width="340" height="660" style="display:block; margin:20px auto;">
  <defs>
    <marker id="ah" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
      <path d="M0 0L10 5L0 10z" fill="#999"/>
    </marker>
  </defs>
  <rect width="340" height="660" fill="#222"/>

  <!-- Input label -->
  <text x="140" y="20" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x (seq_len, 512)</text>

  <!-- Arrow: input -> RMSNorm1 -->
  <line x1="140" y1="28" x2="140" y2="46" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Residual 1: fork right, down, back left to Add1 -->
  <line x1="140" y1="36" x2="280" y2="36" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="280" y1="36" x2="280" y2="306" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="280" y1="306" x2="157" y2="306" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
  <text x="290" y="175" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,290,175)">residual</text>

  <!-- RMSNorm1 -->
  <rect x="60" y="49" width="160" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
  <text x="140" y="69" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">RMSNorm 1</text>
  <line x1="140" y1="79" x2="140" y2="97" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- QKV Proj -->
  <rect x="60" y="100" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
  <text x="140" y="120" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">QKV Projection (GQA)</text>
  <line x1="140" y1="130" x2="140" y2="148" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- RoPE -->
  <rect x="60" y="151" width="160" height="30" rx="5" fill="#2d1e4d" stroke="#7755bb" stroke-width="1"/>
  <text x="140" y="171" text-anchor="middle" fill="#ccaaee" font-size="12" font-family="sans-serif">RoPE (Q and K)</text>
  <line x1="140" y1="181" x2="140" y2="199" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Causal Attn -->
  <rect x="60" y="202" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
  <text x="140" y="222" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Causal Attention</text>
  <line x1="140" y1="232" x2="140" y2="250" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Output Proj -->
  <rect x="60" y="253" width="160" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
  <text x="140" y="273" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Output Projection</text>
  <line x1="140" y1="283" x2="140" y2="294" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Add 1 -->
  <circle cx="140" cy="306" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
  <text x="140" y="311" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
  <line x1="140" y1="318" x2="140" y2="342" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Residual 2 -->
  <line x1="140" y1="330" x2="280" y2="330" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="280" y1="330" x2="280" y2="586" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="280" y1="586" x2="157" y2="586" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
  <text x="290" y="460" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,290,460)">residual</text>

  <!-- RMSNorm2 -->
  <rect x="60" y="345" width="160" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
  <text x="140" y="365" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">RMSNorm 2</text>
  <line x1="140" y1="375" x2="140" y2="393" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Gate + Up Proj -->
  <rect x="60" y="396" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
  <text x="140" y="416" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Gate &amp; Up Proj (512→1408)</text>
  <line x1="140" y1="426" x2="140" y2="444" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- SiLU + multiply -->
  <rect x="60" y="447" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
  <text x="140" y="467" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">SiLU(gate) ⊙ up</text>
  <line x1="140" y1="477" x2="140" y2="495" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Down Proj -->
  <rect x="60" y="498" width="160" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
  <text x="140" y="518" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Down Proj (1408→512)</text>
  <line x1="140" y1="528" x2="140" y2="574" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Add 2 -->
  <circle cx="140" cy="586" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
  <text x="140" y="591" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
  <line x1="140" y1="598" x2="140" y2="622" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Output label -->
  <text x="140" y="640" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">output (seq_len, 512)</text>
</svg>

<p>
  The block follows Llama's <strong>pre-norm</strong> architecture. Unlike GPT-2, it uses
  <strong>RMSNorm</strong> (no mean subtraction, no additive bias), <strong>Grouped Query
  Attention</strong> with 8 query heads and 2 key/value heads, <strong>Rotary Position
  Embeddings</strong> applied to Q and K, and a <strong>SwiGLU</strong> feed-forward network.
  None of the linear projections have bias terms.
</p>

\[
\begin{aligned}
x' &= x + \text{Attn}\!\left(\text{RMSNorm}_1(x),\; \cos,\; \sin\right) \\[4pt]
\text{output} &= x' + \text{FFN}\!\left(\text{RMSNorm}_2(x')\right)
\end{aligned}
\]
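
<p>Concretely, these two equations are just two residual stages wrapped around the attention and
FFN sub-blocks. Below is a minimal CPU-style sketch in C of the top-level flow. The names
(<code>block_forward</code>, <code>rmsnorm</code>, <code>attention</code>, <code>ffn</code>) are
illustrative, not part of the required <code>solve</code> interface; the helpers themselves are
sketched further down.</p>
<pre>
#include &lt;stdlib.h&gt;

#define D_MODEL 512
#define OFF_W2  655872   /* offset of the RMSNorm 2 scale (see Weight Layout) */

/* Hypothetical helpers, sketched later in this document. */
void rmsnorm(const float *z, const float *w, float *out, int seq_len);
void attention(const float *normed, const float *weights, const float *cos_t,
               const float *sin_t, float *out, int seq_len);
void ffn(const float *normed, const float *weights, float *out, int seq_len);

/* x' = x + Attn(RMSNorm1(x), cos, sin);  output = x' + FFN(RMSNorm2(x')) */
void block_forward(const float *x, const float *weights, const float *cos_t,
                   const float *sin_t, float *output, int seq_len) {
    float *normed = malloc(sizeof(float) * seq_len * D_MODEL);
    float *sub    = malloc(sizeof(float) * seq_len * D_MODEL);

    rmsnorm(x, weights /* w1 at offset 0 */, normed, seq_len);
    attention(normed, weights, cos_t, sin_t, sub, seq_len);
    for (int i = 0; i &lt; seq_len * D_MODEL; i++)
        output[i] = x[i] + sub[i];              /* first residual: x' */

    rmsnorm(output, weights + OFF_W2, normed, seq_len);
    ffn(normed, weights, sub, seq_len);
    for (int i = 0; i &lt; seq_len * D_MODEL; i++)
        output[i] += sub[i];                    /* second residual */

    free(normed);
    free(sub);
}
</pre>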

<p>The sub-operations in detail:</p>

\[
\begin{aligned}
\text{RMSNorm}(z, w) &= \frac{z}{\sqrt{\frac{1}{d}\sum_i z_i^2 + \varepsilon}} \odot w, \quad \varepsilon = 10^{-5} \\[8pt]
Q &= \text{RMSNorm}_1(x)\, W_Q^\top \in \mathbb{R}^{T \times 512}, \quad \text{reshape to } (T, 8, 64) \\[4pt]
K &= \text{RMSNorm}_1(x)\, W_K^\top \in \mathbb{R}^{T \times 128}, \quad \text{reshape to } (T, 2, 64) \\[4pt]
V &= \text{RMSNorm}_1(x)\, W_V^\top \in \mathbb{R}^{T \times 128}, \quad \text{reshape to } (T, 2, 64) \\[8pt]
\text{RoPE}(q, \cos, \sin) &: \quad [q_1 \mid q_2] \mapsto [q_1 \odot \cos - q_2 \odot \sin \mid q_1 \odot \sin + q_2 \odot \cos] \\[4pt]
&\quad q_1 = q[\ldots, {:}32],\; q_2 = q[\ldots, {32:}] \\[8pt]
\text{GQA} &: \text{repeat } K, V \text{ along the head dim } 4\times \text{ to match the 8 Q heads} \\[4pt]
\text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{64}} + M_{\text{causal}}\right) V_i \\[8pt]
\text{Attn}(x, \cos, \sin) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_8)\; W_O^\top \\[8pt]
\text{FFN}(z) &= \bigl(\text{SiLU}(z\, W_{\text{gate}}^\top) \odot (z\, W_{\text{up}}^\top)\bigr)\, W_{\text{down}}^\top
\end{aligned}
\]

<p>where \(M_{\text{causal}}\) is the strictly upper-triangular causal mask (\(-\infty\) above the
diagonal, \(0\) on and below it) and \(\text{SiLU}(x) = x \cdot \sigma(x)\).</p>
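
<p>The two operations easiest to get subtly wrong are RMSNorm (no mean subtraction, no additive
bias) and rotate-half RoPE. A minimal C sketch of both, assuming row-major tensors; the function
names are illustrative:</p>
<pre>
#include &lt;math.h&gt;

#define D_MODEL   512
#define ROPE_HALF 32    /* head_dim / 2 */

/* RMSNorm over the feature dimension: z / sqrt(mean(z^2) + eps) * w.
   Note: no mean subtraction and no additive bias. */
void rmsnorm(const float *z, const float *w, float *out, int seq_len) {
    for (int t = 0; t &lt; seq_len; t++) {
        const float *row = z + t * D_MODEL;
        float ss = 0.0f;
        for (int i = 0; i &lt; D_MODEL; i++) ss += row[i] * row[i];
        float inv = 1.0f / sqrtf(ss / D_MODEL + 1e-5f);
        for (int i = 0; i &lt; D_MODEL; i++)
            out[t * D_MODEL + i] = row[i] * inv * w[i];
    }
}

/* Rotate-half RoPE applied in place to one 64-wide head vector q at
   position t; cos_t and sin_t point at row t of the (seq_len, 32) tables. */
void rope(float *q, const float *cos_t, const float *sin_t) {
    for (int i = 0; i &lt; ROPE_HALF; i++) {
        float q1 = q[i];               /* first half  */
        float q2 = q[i + ROPE_HALF];   /* second half */
        q[i]             = q1 * cos_t[i] - q2 * sin_t[i];
        q[i + ROPE_HALF] = q1 * sin_t[i] + q2 * cos_t[i];
    }
}
</pre>
<p>The same rotation is applied to every one of the 8 query heads and both key heads; V is left
unrotated.</p>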

<h2>Implementation Requirements</h2>
<ul>
  <li>Use only the language's native features (external libraries are not permitted)</li>
  <li>The <code>solve</code> function signature must remain unchanged</li>
  <li>The final result must be stored in the <code>output</code> tensor</li>
  <li>RMSNorm uses \(\varepsilon = 10^{-5}\), with no additive bias</li>
  <li>Apply causal masking: position \(i\) attends only to positions \(\le i\)</li>
  <li>Repeat K and V heads \(4\times\) (GQA groups) before computing attention
      (see the attention sketch after this list)</li>
  <li><code>cos</code> and <code>sin</code> have shape <code>(seq_len, 32)</code> — apply
      them to every Q and K head independently</li>
</ul>
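
<p>A naive \(O(T^2)\) C sketch of causal GQA attention. Instead of materializing the \(4\times\)
repeated K and V, each query head simply indexes its shared kv head, which is an equivalent (and
cheaper) reading of the repeat. Names are illustrative:</p>
<pre>
#include &lt;math.h&gt;
#include &lt;stdlib.h&gt;

#define HD    64            /* head_dim */
#define N_Q   8             /* query heads */
#define N_KV  2             /* key/value heads */
#define GROUP (N_Q / N_KV)  /* 4 query heads per kv head */

/* One query head of causal attention. q, k, v, out are (T, HD) slices.
   The max subtraction inside softmax is for numerical stability only. */
static void attend_head(const float *q, const float *k, const float *v,
                        float *out, int T) {
    float scale = 1.0f / sqrtf((float)HD);
    float *s = malloc(sizeof(float) * T);
    for (int i = 0; i &lt; T; i++) {
        float maxs = -INFINITY;
        for (int j = 0; j &lt;= i; j++) {      /* causal: keys j &lt;= i only */
            float dot = 0.0f;
            for (int d = 0; d &lt; HD; d++) dot += q[i*HD+d] * k[j*HD+d];
            s[j] = dot * scale;
            if (s[j] > maxs) maxs = s[j];
        }
        float denom = 0.0f;
        for (int j = 0; j &lt;= i; j++) { s[j] = expf(s[j] - maxs); denom += s[j]; }
        for (int d = 0; d &lt; HD; d++) {
            float acc = 0.0f;
            for (int j = 0; j &lt;= i; j++) acc += s[j] * v[j*HD+d];
            out[i*HD+d] = acc / denom;
        }
    }
    free(s);
}

/* GQA: query head h reads kv head h / GROUP, i.e. each kv head is logically
   repeated 4x. Q is (N_Q, T, HD); K and V are (N_KV, T, HD). */
void gqa_attention(const float *Q, const float *K, const float *V,
                   float *out, int T) {
    for (int h = 0; h &lt; N_Q; h++) {
        int kv = h / GROUP;
        attend_head(Q + h*T*HD, K + kv*T*HD, V + kv*T*HD, out + h*T*HD, T);
    }
}
</pre>
<p>The 8 head outputs are then concatenated back to <code>(T, 512)</code> and multiplied by
\(W_O^\top\).</p>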

<h2>Weight Layout</h2>
<p>All parameters are packed into a single contiguous <code>weights</code> buffer
(2,819,072 floats) in the order below. All 2-D matrices are stored row-major
with shape <code>(out_dim, in_dim)</code>. There are no bias terms.</p>

<table style="border-collapse:separate; border-spacing:16px 6px;">
  <tr>
    <th style="text-align:left;">Parameter</th>
    <th style="text-align:left;">Shape</th>
    <th style="text-align:right;">Size</th>
    <th style="text-align:right;">Offset</th>
  </tr>
  <tr>
    <td>\(w_1\) (RMSNorm 1 scale)</td>
    <td>(512,)</td>
    <td style="text-align:right;">512</td>
    <td style="text-align:right;">0</td>
  </tr>
  <tr>
    <td>\(W_Q\)</td>
    <td>(512, 512)</td>
    <td style="text-align:right;">262,144</td>
    <td style="text-align:right;">512</td>
  </tr>
  <tr>
    <td>\(W_K\)</td>
    <td>(128, 512)</td>
    <td style="text-align:right;">65,536</td>
    <td style="text-align:right;">262,656</td>
  </tr>
  <tr>
    <td>\(W_V\)</td>
    <td>(128, 512)</td>
    <td style="text-align:right;">65,536</td>
    <td style="text-align:right;">328,192</td>
  </tr>
  <tr>
    <td>\(W_O\)</td>
    <td>(512, 512)</td>
    <td style="text-align:right;">262,144</td>
    <td style="text-align:right;">393,728</td>
  </tr>
  <tr>
    <td>\(w_2\) (RMSNorm 2 scale)</td>
    <td>(512,)</td>
    <td style="text-align:right;">512</td>
    <td style="text-align:right;">655,872</td>
  </tr>
  <tr>
    <td>\(W_{\text{gate}}\)</td>
    <td>(1408, 512)</td>
    <td style="text-align:right;">720,896</td>
    <td style="text-align:right;">656,384</td>
  </tr>
  <tr>
    <td>\(W_{\text{up}}\)</td>
    <td>(1408, 512)</td>
    <td style="text-align:right;">720,896</td>
    <td style="text-align:right;">1,377,280</td>
  </tr>
  <tr>
    <td>\(W_{\text{down}}\)</td>
    <td>(512, 1408)</td>
    <td style="text-align:right;">720,896</td>
    <td style="text-align:right;">2,098,176</td>
  </tr>
</table>
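
<p>Each offset is simply the previous offset plus the previous size. A sketch of how a solution
might slice the buffer, together with a row-major matvec and the SwiGLU FFN for a single (already
normalized) token row; all names are illustrative:</p>
<pre>
#include &lt;math.h&gt;

#define D_MODEL 512
#define FFN_H   1408

/* Offsets (in floats) into the packed weights buffer, from the table above. */
enum {
    OFF_W1    = 0,
    OFF_WQ    = 512,
    OFF_WK    = 262656,
    OFF_WV    = 328192,
    OFF_WO    = 393728,
    OFF_W2    = 655872,
    OFF_WGATE = 656384,
    OFF_WUP   = 1377280,
    OFF_WDOWN = 2098176
};

/* Row-major (out_dim, in_dim) matvec: each row of W holds one output
   feature, so y = matvec(W, x) computes x W^T from the formulas above. */
void matvec(const float *W, const float *x, float *y, int out_dim, int in_dim) {
    for (int o = 0; o &lt; out_dim; o++) {
        float acc = 0.0f;
        for (int i = 0; i &lt; in_dim; i++) acc += W[o * in_dim + i] * x[i];
        y[o] = acc;
    }
}

/* SwiGLU FFN for one normalized token row z of length 512:
   down( SiLU(gate(z)) ⊙ up(z) ), with SiLU(x) = x * sigmoid(x). */
void ffn_row(const float *weights, const float *z, float *out) {
    float gate[FFN_H], up[FFN_H];
    matvec(weights + OFF_WGATE, z, gate, FFN_H, D_MODEL);
    matvec(weights + OFF_WUP,   z, up,   FFN_H, D_MODEL);
    for (int i = 0; i &lt; FFN_H; i++) {
        float g = gate[i];
        gate[i] = (g / (1.0f + expf(-g))) * up[i];   /* SiLU(gate) ⊙ up */
    }
    matvec(weights + OFF_WDOWN, gate, out, D_MODEL, FFN_H);
}
</pre>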

<h2>Example</h2>
<p>With <code>seq_len</code> = 4, <code>x</code> drawn uniformly from [−1, 1], and randomly
initialized weights:</p>
<pre>
Input:  x.shape       = (4, 512)    # 4 token hidden states
        weights.shape = (2819072,)  # packed weight buffer (2,819,072 floats)
        cos.shape     = (4, 32)     # precomputed RoPE cosines
        sin.shape     = (4, 32)     # precomputed RoPE sines
        seq_len       = 4
Output: output.shape  = (4, 512)    # transformed token hidden states
</pre>
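
<p>The <code>cos</code>/<code>sin</code> tables are supplied as inputs, so a solution never builds
them. For intuition only, this is how such tables are conventionally generated in Llama-style
models, assuming the common base of 10000 (an assumption: the base actually used by the test
harness is not specified here):</p>
<pre>
#include &lt;math.h&gt;

#define ROPE_HALF 32   /* head_dim / 2 */

/* Illustrative only: the real tables are inputs to solve(). Conventional
   Llama-style RoPE frequencies with an ASSUMED base of 10000. */
void build_rope_tables(float *cos_t, float *sin_t, int seq_len) {
    for (int t = 0; t &lt; seq_len; t++) {
        for (int i = 0; i &lt; ROPE_HALF; i++) {
            float freq = powf(10000.0f, -2.0f * (float)i / 64.0f);
            cos_t[t * ROPE_HALF + i] = cosf((float)t * freq);
            sin_t[t * ROPE_HALF + i] = sinf((float)t * freq);
        }
    }
}
</pre>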

<h2>Constraints</h2>
<ul>
  <li><code>d_model</code> = 512, <code>n_q_heads</code> = 8, <code>n_kv_heads</code> = 2,
      <code>head_dim</code> = 64, <code>ffn_hidden</code> = 1,408</li>
  <li>1 ≤ <code>seq_len</code> ≤ 4,096</li>
  <li>All tensors use 32-bit floating point</li>
  <li>Performance is measured with <code>seq_len</code> = 2,048</li>
</ul>