Commit 8b125d8

Clean up spec
1 parent bad839a commit 8b125d8

1 file changed

Lines changed: 14 additions & 126 deletions

challenges/hard/83_turboquant_attention/challenge.html

@@ -6,122 +6,8 @@
66
compressed keys and compute dot-product attention scores against full-precision queries.
77
</p>
88

9-
<svg width="660" height="300" viewBox="0 0 660 300" xmlns="http://www.w3.org/2000/svg"
10-
style="display:block; margin:20px auto; font-family:monospace;">
11-
<rect width="660" height="300" fill="#222" rx="8"/>
12-
13-
<!-- Title -->
14-
<text x="330" y="20" text-anchor="middle" fill="#888" font-size="10">TurboQuant Dequantization Pipeline (per key vector)</text>
15-
16-
<!-- ============================================================ -->
17-
<!-- Stage 1: MSE dequantization (top row) -->
18-
<!-- ============================================================ -->
19-
<text x="16" y="50" fill="#4a9edd" font-size="10" font-weight="bold">Stage 1: MSE</text>
20-
21-
<!-- K_idx -->
22-
<rect x="16" y="60" width="80" height="40" rx="4" fill="#1a3a5c" stroke="#4a9edd" stroke-width="1.5"/>
23-
<text x="56" y="78" text-anchor="middle" fill="#4a9edd" font-size="10">K_idx</text>
24-
<text x="56" y="92" text-anchor="middle" fill="#7799bb" font-size="8">[S, D] uint8</text>
25-
26-
<!-- arrow -->
27-
<text x="110" y="84" fill="#888" font-size="11" text-anchor="middle">&#x2192;</text>
28-
29-
<!-- codebook lookup -->
30-
<rect x="124" y="60" width="120" height="40" rx="4" fill="#1a1a1a" stroke="#666" stroke-width="1"/>
31-
<text x="184" y="78" text-anchor="middle" fill="#ccc" font-size="9">codebook[ K_idx ]</text>
32-
<text x="184" y="92" text-anchor="middle" fill="#888" font-size="8">centroid lookup</text>
33-
34-
<!-- arrow -->
35-
<text x="258" y="84" fill="#888" font-size="11" text-anchor="middle">&#x2192;</text>
36-
37-
<!-- multiply by Pi -->
38-
<rect x="272" y="60" width="80" height="40" rx="4" fill="#1a1a1a" stroke="#666" stroke-width="1"/>
39-
<text x="312" y="78" text-anchor="middle" fill="#ccc" font-size="9">&#xd7; &#x3a0;</text>
40-
<text x="312" y="92" text-anchor="middle" fill="#888" font-size="8">rotate back</text>
41-
42-
<!-- = -->
43-
<text x="366" y="84" fill="#888" font-size="11" text-anchor="middle">=</text>
44-
45-
<!-- K_mse result -->
46-
<rect x="380" y="60" width="100" height="40" rx="4" fill="#1a3a5c" stroke="#4a9edd" stroke-width="1.5"/>
47-
<text x="430" y="82" text-anchor="middle" fill="#4a9edd" font-size="10">K&#x303;_mse</text>
48-
<text x="430" y="94" text-anchor="middle" fill="#7799bb" font-size="8">[S, D] float32</text>
49-
50-
<!-- ============================================================ -->
51-
<!-- Stage 2: QJL dequantization (middle row) -->
52-
<!-- ============================================================ -->
53-
<text x="16" y="130" fill="#52b788" font-size="10" font-weight="bold">Stage 2: QJL residual</text>
54-
55-
<!-- sigma (signs) -->
56-
<rect x="16" y="140" width="60" height="40" rx="4" fill="#1a3a2a" stroke="#52b788" stroke-width="1.5"/>
57-
<text x="46" y="158" text-anchor="middle" fill="#52b788" font-size="10">&#x3c3;</text>
58-
<text x="46" y="172" text-anchor="middle" fill="#77bbaa" font-size="8">[S,D] &#xb1;1</text>
59-
60-
<!-- multiply by S_proj -->
61-
<text x="90" y="164" fill="#888" font-size="11" text-anchor="middle">&#x2192;</text>
62-
63-
<rect x="104" y="140" width="60" height="40" rx="4" fill="#1a1a1a" stroke="#666" stroke-width="1"/>
64-
<text x="134" y="158" text-anchor="middle" fill="#ccc" font-size="9">&#x3c3; &#xb7; M</text>
65-
<text x="134" y="172" text-anchor="middle" fill="#888" font-size="8">project</text>
66-
67-
<!-- multiply by scale * gamma -->
68-
<text x="178" y="164" fill="#888" font-size="11" text-anchor="middle">&#x2192;</text>
69-
70-
<rect x="192" y="140" width="120" height="40" rx="4" fill="#1a1a1a" stroke="#666" stroke-width="1"/>
71-
<text x="252" y="158" text-anchor="middle" fill="#ccc" font-size="9">&#xd7; &#x221a;(&#x3c0;/2)/D &#xd7; &#x3b3;</text>
72-
<text x="252" y="172" text-anchor="middle" fill="#888" font-size="8">scale by norm</text>
73-
74-
<!-- = -->
75-
<text x="326" y="164" fill="#888" font-size="11" text-anchor="middle">=</text>
76-
77-
<!-- K_qjl result -->
78-
<rect x="340" y="140" width="100" height="40" rx="4" fill="#1a3a2a" stroke="#52b788" stroke-width="1.5"/>
79-
<text x="390" y="162" text-anchor="middle" fill="#52b788" font-size="10">K&#x303;_res</text>
80-
<text x="390" y="174" text-anchor="middle" fill="#77bbaa" font-size="8">[S, D] float32</text>
81-
82-
<!-- ============================================================ -->
83-
<!-- Combine + dot product (bottom) -->
84-
<!-- ============================================================ -->
85-
<text x="16" y="210" fill="#e0a040" font-size="10" font-weight="bold">Combine + Score</text>
86-
87-
<!-- K_mse + K_res -->
88-
<rect x="16" y="220" width="100" height="36" rx="4" fill="#1a3a5c" stroke="#4a9edd" stroke-width="1"/>
89-
<text x="66" y="242" text-anchor="middle" fill="#4a9edd" font-size="10">K&#x303;_mse</text>
90-
91-
<text x="128" y="242" fill="#ccc" font-size="14" text-anchor="middle">+</text>
92-
93-
<rect x="142" y="220" width="100" height="36" rx="4" fill="#1a3a2a" stroke="#52b788" stroke-width="1"/>
94-
<text x="192" y="242" text-anchor="middle" fill="#52b788" font-size="10">K&#x303;_res</text>
95-
96-
<text x="256" y="242" fill="#ccc" font-size="14" text-anchor="middle">=</text>
97-
98-
<rect x="270" y="220" width="80" height="36" rx="4" fill="#3a2a1a" stroke="#e0a040" stroke-width="1.5"/>
99-
<text x="310" y="242" text-anchor="middle" fill="#e0a040" font-size="10">K&#x303;</text>
100-
101-
<!-- dot with Q -->
102-
<text x="370" y="242" fill="#ccc" font-size="10" text-anchor="middle">then:</text>
103-
104-
<rect x="394" y="220" width="70" height="36" rx="4" fill="#2a1a3a" stroke="#c060e0" stroke-width="1.5"/>
105-
<text x="429" y="242" text-anchor="middle" fill="#c060e0" font-size="10">Q</text>
106-
107-
<text x="478" y="242" fill="#ccc" font-size="14" text-anchor="middle">&#xb7;</text>
108-
109-
<rect x="492" y="220" width="70" height="36" rx="4" fill="#3a2a1a" stroke="#e0a040" stroke-width="1"/>
110-
<text x="527" y="242" text-anchor="middle" fill="#e0a040" font-size="10">K&#x303;&#x1d40;</text>
111-
112-
<text x="576" y="242" fill="#ccc" font-size="14" text-anchor="middle">=</text>
113-
114-
<rect x="590" y="220" width="56" height="36" rx="4" fill="#3a1a1a" stroke="#e05050" stroke-width="1.5"/>
115-
<text x="618" y="242" text-anchor="middle" fill="#e05050" font-size="10">scores</text>
116-
117-
<!-- Legend -->
118-
<text x="16" y="280" fill="#888" font-size="9">&#x3a0; = orthogonal rotation [D&#xd7;D]</text>
119-
<text x="200" y="280" fill="#888" font-size="9">M = Gaussian projection [D&#xd7;D]</text>
120-
<text x="420" y="280" fill="#888" font-size="9">&#x3c3; = sign bits &#xb1;1, &#x3b3; = &#x2016;residual&#x2016;&#x2082;</text>
121-
</svg>
122-
1239
<p>
124-
<strong>Background how the keys were compressed</strong> (already done for you, not part of the challenge):
10+
<strong>Background - how the keys were compressed</strong> (already done for you, not part of the challenge):
12511
</p>
12612
<ol>
12713
<li><strong>Rotate</strong>: multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
@@ -131,21 +17,21 @@
13117
<li><strong>Residual correction</strong>: MSE quantization loses information. Compute the residual
13218
\(r = K - \tilde{K}_\text{mse}\), then store:
13319
<ul>
134-
<li>\(\sigma = \text{sign}(M \cdot r) \in \{-1,+1\}^D\) direction (<code>int8</code>)</li>
135-
<li>\(\gamma = \|r\|_2\) magnitude (<code>float32</code> scalar per key)</li>
20+
<li>\(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (<code>int8</code>)</li>
21+
<li>\(\gamma = \|r\|_2\) - magnitude (<code>float32</code> scalar per key)</li>
13622
</ul>
137-
where \(M \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix (<code>S_mat</code> in code).
23+
where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.
13824
</li>
13925
</ol>
14026

14127
<p>
142-
<strong>What you compute</strong> dequantize and score:
28+
<strong>What you compute</strong> - dequantize and score:
14329
</p>
14430
<ol>
14531
<li><strong>MSE dequantize</strong>: look up centroids, undo the rotation:
14632
\[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]</li>
14733
<li><strong>Residual dequantize</strong>: reconstruct the residual correction:
148-
\[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot M\]
34+
\[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat}\]
14935
The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.</li>
15036
<li><strong>Combine</strong>:
15137
\(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)</li>
@@ -166,19 +52,20 @@ <h2>Implementation Requirements</h2>
16652

16753
<h2>Example</h2>
16854
<p>
169-
Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(M = I\), \(\gamma = \mathbf{0}\) (residual correction disabled):
55+
Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled),
56+
\(\sigma = \mathbf{1}\) (all +1):
17057
</p>
17158
<p>
172-
\(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\), \quad
173-
\(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\), \quad
59+
\(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
60+
\(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
17461
codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
17562
</p>
17663
<p>
177-
Step 1 MSE lookup and rotate back (\(\Pi = I\)):
64+
Step 1 - MSE lookup and rotate back (\(\Pi = I\)):
17865
\[
17966
\tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
18067
\]
181-
Step 2 Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
68+
Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
18269
</p>
18370
<p>
18471
Output:
@@ -194,7 +81,8 @@ <h2>Constraints</h2>
19481
<li>1 &le; <code>D</code> &le; 256</li>
19582
<li>2 &le; <code>C</code> &le; 256</li>
19683
<li>\(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))</li>
197-
<li>\(M\) (<code>S_mat</code>) has i.i.d. \(\mathcal{N}(0,1)\) entries</li>
84+
<li><code>S_mat</code> has i.i.d. \(\mathcal{N}(0,1)\) entries</li>
85+
<li><code>gamma</code> has shape \([S]\) (one \(\ell_2\) norm per key vector, <code>float32</code>)</li>
19886
<li><code>qjl_signs</code> (\(\sigma\)) values are in \(\{-1, +1\}\) (<code>int8</code>)</li>
19987
<li><code>K_idx</code> values are in \([0, C)\) (<code>uint8</code>)</li>
20088
<li>All floating-point inputs are <code>float32</code></li>
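
For reviewers who want to sanity-check the cleaned-up spec, the dequantize-and-score steps described in the diff can be sketched in plain Python. This is a minimal sketch, not the challenge's reference solution: the function name `dequantize_and_score`, the helper `matmul`, and the argument order are hypothetical, and it assumes `Q` has shape [B, D], `K_idx` and `qjl_signs` have shape [S, D], `gamma` has shape [S], and that the scores are the plain product of `Q` with the transposed dequantized keys, with no extra scaling.

```python
import math


def matmul(a, b):
    """Multiply two row-major matrices given as lists of lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def dequantize_and_score(Q, K_idx, codebook, Pi, S_mat, qjl_signs, gamma):
    """Hypothetical helper: rebuild keys from compressed form, then score against Q."""
    D = len(Pi)

    # MSE dequantize: look up centroids, then undo the rotation (codebook[K_idx] . Pi).
    centroids = [[codebook[i] for i in row] for row in K_idx]
    K_mse = matmul(centroids, Pi)

    # Residual dequantize: sqrt(pi/2)/D * gamma * (sigma . S_mat), one gamma per key.
    scale = math.sqrt(math.pi / 2) / D
    projected = matmul(qjl_signs, S_mat)
    K_res = [[scale * g * v for v in row] for g, row in zip(gamma, projected)]

    # Combine the two reconstructions element-wise.
    K = [[m + r for m, r in zip(m_row, r_row)] for m_row, r_row in zip(K_mse, K_res)]

    # Score: Q . K^T, giving one row per query and one column per key ([B, S]).
    K_T = [list(col) for col in zip(*K)]
    return matmul(Q, K_T)


# Inputs from the spec's example: Pi = I, S_mat = I, gamma = 0, sigma = all +1.
Q = [[1.0, 0.0], [0.0, 1.0]]
K_idx = [[0, 3], [1, 2], [3, 0]]
codebook = [-0.75, -0.25, 0.25, 0.75]
identity = [[1.0, 0.0], [0.0, 1.0]]
signs = [[1, 1], [1, 1], [1, 1]]
gamma = [0.0, 0.0, 0.0]

scores = dequantize_and_score(Q, K_idx, codebook, identity, identity, signs, gamma)
print(scores)
```

With these inputs the residual term vanishes, so under the assumptions above the sketch simply returns the transpose of the `K_mse` matrix shown in Step 1. A real submission would operate on flat `float32` buffers rather than nested lists; the list form is used here only to keep the arithmetic easy to follow.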
