| 1 | +<p> |
| 2 | + Implement decaying causal attention. Given query matrix <code>Q</code>, key matrix <code>K</code>, |
| 3 | + and value matrix <code>V</code>, each of shape <code>seq_len × d_model</code>, and a scalar |
| 4 | + decay factor <code>gamma</code> ∈ (0, 1], compute the unnormalized causal attention output |
| 5 | + where position <code>n</code> attends to every position <code>m ≤ n</code> (itself and all earlier |
| 6 | + positions), with the score for position <code>m</code> scaled by <code>gamma<sup>n−m</sup></code>: |
| 7 | +</p> |
| 8 | +<p> |
| 9 | + \[ |
| 10 | + \text{output}[n] = \sum_{m=0}^{n} \gamma^{n-m} \cdot \frac{Q[n] \cdot K[m]}{\sqrt{d_{\text{model}}}} \cdot V[m] |
| 11 | + \] |
| 12 | +</p> |
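<p>
 The sum above can be evaluated literally with two nested loops. The following is a minimal CPU sketch in
 pure Python, intended only as a readable reference for checking results; the function name
 <code>decaying_causal_attention</code> and the list-of-lists layout are assumptions made for
 illustration and are not the required <code>solve</code> signature.
</p>

```python
import math


def decaying_causal_attention(Q, K, V, gamma):
    """Evaluate output[n] = sum_{m<=n} gamma^(n-m) * (Q[n].K[m] / sqrt(d_model)) * V[m].

    Q, K, V are lists of seq_len rows, each row a list of d_model floats.
    """
    seq_len, d_model = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d_model)
    output = [[0.0] * d_model for _ in range(seq_len)]
    for n in range(seq_len):
        for m in range(n + 1):  # causal: only positions m <= n contribute
            score = scale * sum(q * k for q, k in zip(Q[n], K[m]))
            weight = (gamma ** (n - m)) * score  # geometric decay, no softmax
            for j in range(d_model):
                output[n][j] += weight * V[m][j]
    return output
```

<p>
 This direct form costs O(seq_len² · d_model) operations, so it is useful for validating small inputs
 rather than for performance.
</p>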
| 13 | +<p> |
| 14 | + Unlike standard softmax attention, the scores are neither exponentiated nor normalized: each scaled dot |
| 15 | + product is multiplied by a weight that decays geometrically with distance from the current position. This is |
| 16 | + the parallel form of the retention mechanism from RetNet, a recurrence-friendly alternative to attention in sequence models. |
| 17 | +</p> |
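<p>
 To illustrate why this form is called recurrence-friendly, the sum can be reordered so that a running
 <code>d_model × d_model</code> state is decayed by <code>gamma</code> and updated once per position:
 \(S_n = \gamma S_{n-1} + K[n]^\top V[n]\) and \(\text{output}[n] = Q[n]\,S_n / \sqrt{d_{\text{model}}}\).
 The sketch below follows from the formula by algebra; the function name and the pure-Python data layout
 are assumptions for illustration, not part of the task.
</p>

```python
import math


def decaying_attention_recurrent(Q, K, V, gamma):
    """Recurrent evaluation of the same output, one position at a time."""
    seq_len, d_model = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d_model)
    # state[i][j] holds sum over m <= n of gamma^(n-m) * K[m][i] * V[m][j]
    state = [[0.0] * d_model for _ in range(d_model)]
    output = []
    for n in range(seq_len):
        # Decay the old state and add the outer product of K[n] and V[n].
        for i in range(d_model):
            for j in range(d_model):
                state[i][j] = gamma * state[i][j] + K[n][i] * V[n][j]
        # Read out: output[n] = Q[n] @ state, scaled by 1 / sqrt(d_model).
        output.append([
            scale * sum(Q[n][i] * state[i][j] for i in range(d_model))
            for j in range(d_model)
        ])
    return output
```

<p>
 Each step touches only the fixed-size state, so the cost is O(seq_len · d_model²) with no
 <code>seq_len × seq_len</code> intermediate. In floating point the result may differ from the parallel
 form by rounding, because the terms are summed in a different order.
</p>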
| 18 | + |
| 19 | +<svg width="680" height="215" viewBox="0 0 680 215" xmlns="http://www.w3.org/2000/svg" |
| 20 | + style="display:block; margin:20px auto;"> |
| 21 | + <rect width="680" height="215" fill="#222" rx="8"/> |
| 22 | + |
| 23 | + <!-- Section title: decay mask --> |
| 24 | + <text x="148" y="24" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">Causal Decay Mask D[n,m] = γ^(n−m)</text> |
| 25 | + |
| 26 | + <!-- Column headers m=0..3 --> |
| 27 | + <text x="80" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=0</text> |
| 28 | + <text x="125" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=1</text> |
| 29 | + <text x="170" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=2</text> |
| 30 | + <text x="215" y="43" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">m=3</text> |
| 31 | + |
| 32 | + <!-- Row labels n=0..3 --> |
| 33 | + <text x="42" y="72" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=0</text> |
| 34 | + <text x="42" y="112" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=1</text> |
| 35 | + <text x="42" y="152" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=2</text> |
| 36 | + <text x="42" y="192" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">n=3</text> |
| 37 | + |
| 38 | + <!-- Row 0 --> |
| 39 | + <rect x="58" y="53" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/> |
| 40 | + <text x="80" y="76" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text> |
| 41 | + <rect x="103" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/> |
| 42 | + <rect x="148" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/> |
| 43 | + <rect x="193" y="53" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/> |
| 44 | + |
| 45 | + <!-- Row 1 --> |
| 46 | + <rect x="58" y="91" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/> |
| 47 | + <text x="80" y="114" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">γ</text> |
| 48 | + <rect x="103" y="91" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/> |
| 49 | + <text x="125" y="114" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text> |
| 50 | + <rect x="148" y="91" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/> |
| 51 | + <rect x="193" y="91" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/> |
| 52 | + |
| 53 | + <!-- Row 2 --> |
| 54 | + <rect x="58" y="129" width="44" height="36" fill="#0e2a54" stroke="#333" stroke-width="1"/> |
| 55 | + <text x="80" y="152" text-anchor="middle" fill="#2a7ec0" font-size="11" font-family="monospace">γ²</text> |
| 56 | + <rect x="103" y="129" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/> |
| 57 | + <text x="125" y="152" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">γ</text> |
| 58 | + <rect x="148" y="129" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/> |
| 59 | + <text x="170" y="152" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text> |
| 60 | + <rect x="193" y="129" width="44" height="36" fill="#161616" stroke="#333" stroke-width="1"/> |
| 61 | + |
| 62 | + <!-- Row 3 --> |
| 63 | + <rect x="58" y="167" width="44" height="36" fill="#081e3c" stroke="#333" stroke-width="1"/> |
| 64 | + <text x="80" y="190" text-anchor="middle" fill="#1a6ea0" font-size="11" font-family="monospace">γ³</text> |
| 65 | + <rect x="103" y="167" width="44" height="36" fill="#0e2a54" stroke="#333" stroke-width="1"/> |
| 66 | + <text x="125" y="190" text-anchor="middle" fill="#2a7ec0" font-size="11" font-family="monospace">γ²</text> |
| 67 | + <rect x="148" y="167" width="44" height="36" fill="#143a6a" stroke="#333" stroke-width="1"/> |
| 68 | + <text x="170" y="190" text-anchor="middle" fill="#3a8ee0" font-size="11" font-family="monospace">γ</text> |
| 69 | + <rect x="193" y="167" width="44" height="36" fill="#1a4a8a" stroke="#333" stroke-width="1"/> |
| 70 | + <text x="215" y="190" text-anchor="middle" fill="#4a9eff" font-size="11" font-family="monospace">1</text> |
| 71 | + |
| 72 | + <!-- Divider --> |
| 73 | + <line x1="265" y1="30" x2="265" y2="210" stroke="#444" stroke-width="1" stroke-dasharray="4,3"/> |
| 74 | + |
| 75 | + <!-- Right side: computation flow --> |
| 76 | + <text x="472" y="24" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">Computation</text> |
| 77 | + |
| 78 | + <defs> |
| 79 | + <marker id="arr2" markerWidth="7" markerHeight="7" refX="5" refY="3" orient="auto"> |
| 80 | + <path d="M0,0 L0,6 L7,3 Z" fill="#888"/> |
| 81 | + </marker> |
| 82 | + </defs> |
| 83 | + |
| 84 | + <!-- Step boxes --> |
| 85 | + <rect x="280" y="48" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/> |
| 86 | + <text x="330" y="62" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">Q [S, D]</text> |
| 87 | + <text x="330" y="74" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">query</text> |
| 88 | + |
| 89 | + <rect x="280" y="90" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/> |
| 90 | + <text x="330" y="104" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">K [S, D]</text> |
| 91 | + <text x="330" y="116" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">key</text> |
| 92 | + |
| 93 | + <rect x="280" y="132" width="100" height="32" rx="4" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.2"/> |
| 94 | + <text x="330" y="146" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">V [S, D]</text> |
| 95 | + <text x="330" y="158" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">value</text> |
| 96 | + |
| 97 | + <!-- Arrow from Q and K to scores --> |
| 98 | + <line x1="380" y1="64" x2="412" y2="90" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/> |
| 99 | + <line x1="380" y1="106" x2="412" y2="96" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/> |
| 100 | + |
| 101 | + <rect x="414" y="78" width="110" height="34" rx="4" fill="#1a2a3c" stroke="#7ec8a0" stroke-width="1.2"/> |
| 102 | + <text x="469" y="92" text-anchor="middle" fill="#7ec8a0" font-size="10" font-family="monospace">QKᵀ / √D</text> |
| 103 | + <text x="469" y="105" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">attn scores [S,S]</text> |
| 104 | + |
| 105 | + <!-- Arrow: multiply by decay mask --> |
| 106 | + <line x1="469" y1="112" x2="469" y2="128" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/> |
| 107 | + <text x="505" y="124" fill="#cc88ff" font-size="9" font-family="monospace">⊙ decay mask</text> |
| 108 | + |
| 109 | + <rect x="414" y="130" width="110" height="34" rx="4" fill="#2a1a3c" stroke="#cc88ff" stroke-width="1.2"/> |
| 110 | + <text x="469" y="144" text-anchor="middle" fill="#cc88ff" font-size="10" font-family="monospace">weighted [S,S]</text> |
| 111 | + <text x="469" y="157" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">lower triangular</text> |
| 112 | + |
| 113 | + <!-- Arrow from V and weighted to output --> |
| 114 | + <line x1="380" y1="148" x2="412" y2="148" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/> |
| 115 | + <line x1="524" y1="147" x2="546" y2="147" stroke="#888" stroke-width="1.2" marker-end="url(#arr2)"/> |
| 116 | + <text x="535" y="140" fill="#888" font-size="9" font-family="monospace">@</text> |
| 117 | + |
| 118 | + <rect x="548" y="131" width="110" height="34" rx="4" fill="#1a3a1c" stroke="#4aff88" stroke-width="1.2"/> |
| 119 | + <text x="603" y="145" text-anchor="middle" fill="#4aff88" font-size="10" font-family="monospace">output [S, D]</text> |
| 120 | + <text x="603" y="158" text-anchor="middle" fill="#888" font-size="9" font-family="monospace">weighted @ V</text> |
| 121 | +</svg> |
| 122 | + |
| 123 | +<h2>Implementation Requirements</h2> |
| 124 | +<ul> |
| 125 | + <li>Implement the <code>solve</code> function; do not change its signature.</li> |
| 126 | + <li>Do not use external libraries beyond those provided.</li> |
| 127 | + <li>Write the result into <code>output</code>.</li> |
| 128 | +</ul> |
| 129 | + |
| 130 | +<h2>Example</h2> |
| 131 | +<p>Example 1, with <code>seq_len</code> = 2, <code>d_model</code> = 4, <code>gamma</code> = 0.5:</p> |
| 132 | +<p> |
| 133 | +\[ |
| 134 | +Q = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 1 & 1 & 0 & 0 \end{bmatrix}, \quad |
| 135 | +K = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}, \quad |
| 136 | +V = \begin{bmatrix} 4 & 8 & 12 & 16 \\ 4 & 8 & 12 & 16 \end{bmatrix} |
| 137 | +\] |
| 138 | +</p> |
| 139 | +<p> |
| 140 | + Attention scores \(QK^\top / \sqrt{4}\): |
| 141 | + \[ |
| 142 | + A = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} |
| 143 | + \] |
| 144 | + Causal decay mask \(D_{nm} = 0.5^{n-m}\) for \(n \ge m\), else \(0\): |
| 145 | + \[ |
| 146 | + D = \begin{bmatrix} 1 & 0 \\ 0.5 & 1 \end{bmatrix} |
| 147 | + \] |
| 148 | + Weighted attention \(A \odot D\): |
| 149 | + \[ |
| 150 | + \begin{bmatrix} 0.5 & 0 \\ 0.25 & 0.5 \end{bmatrix} |
| 151 | + \] |
| 152 | + Output \((A \odot D)\,V\): |
| 153 | + \[ |
| 154 | + \text{output} = \begin{bmatrix} 2 & 4 & 6 & 8 \\ 3 & 6 & 9 & 12 \end{bmatrix} |
| 155 | + \] |
| 156 | +</p> |
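<p>
 The four steps of this example (scores, decay mask, elementwise product, multiplication by
 <code>V</code>) can be reproduced with a short pure-Python sketch. The helper names
 <code>matmul</code> and <code>decay_attention_steps</code> are hypothetical and exist only for this
 illustration.
</p>

```python
import math


def matmul(A, B):
    """Plain matrix product of two list-of-lists matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def decay_attention_steps(Q, K, V, gamma):
    """Return (scores, mask, weighted, output) for the masked matrix form."""
    seq_len, d_model = len(Q), len(Q[0])
    K_T = [list(col) for col in zip(*K)]  # transpose of K
    scores = [[s / math.sqrt(d_model) for s in row] for row in matmul(Q, K_T)]
    # D[n][m] = gamma^(n-m) on and below the diagonal, 0 above it.
    mask = [[gamma ** (n - m) if n >= m else 0.0 for m in range(seq_len)]
            for n in range(seq_len)]
    weighted = [[a * d for a, d in zip(a_row, d_row)]
                for a_row, d_row in zip(scores, mask)]
    return scores, mask, weighted, matmul(weighted, V)


Q = [[1, 1, 0, 0], [1, 1, 0, 0]]
K = [[1, 0, 0, 0], [0, 1, 0, 0]]
V = [[4, 8, 12, 16], [4, 8, 12, 16]]

scores, mask, weighted, output = decay_attention_steps(Q, K, V, gamma=0.5)
for name, matrix in (("A", scores), ("D", mask), ("A*D", weighted), ("output", output)):
    print(name, matrix)
```

<p>
 Running it on the inputs above yields the same <code>A</code>, <code>D</code>, weighted matrix, and
 output as the worked example.
</p>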
| 157 | + |
| 158 | +<h2>Constraints</h2> |
| 159 | +<ul> |
| 160 | + <li>1 ≤ <code>seq_len</code> ≤ 8,192</li> |
| 161 | + <li>1 ≤ <code>d_model</code> ≤ 256</li> |
| 162 | + <li>0 &lt; <code>gamma</code> ≤ 1</li> |
| 163 | + <li>All tensors are <code>float32</code> on GPU.</li> |
| 164 | + <li>Performance is measured with <code>seq_len</code> = 4,096 and <code>d_model</code> = 64.</li> |
| 165 | +</ul> |
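<p>
 These bounds determine how large a fully materialized <code>seq_len × seq_len</code> score matrix would
 be in <code>float32</code>. The small sketch below only does that arithmetic; the helper name
 <code>score_matrix_bytes</code> is hypothetical, and whether such an intermediate is acceptable depends
 on the target GPU and the chosen implementation strategy.
</p>

```python
def score_matrix_bytes(seq_len, bytes_per_element=4):
    """Size in bytes of a dense seq_len x seq_len matrix (float32 by default)."""
    return seq_len * seq_len * bytes_per_element


for seq_len in (4096, 8192):
    mib = score_matrix_bytes(seq_len) / 2**20
    print(f"seq_len={seq_len}: {mib:.0f} MiB")
```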