<p>
Implement attention score computation against a
<a href="https://arxiv.org/abs/2504.19874" target="_blank" style="color:#4a9eff; text-decoration:underline;">TurboQuant</a>-compressed
KV cache. TurboQuant compresses each key vector to <code>uint8</code> codebook indices plus a 1-bit
residual correction (QJL), reducing KV cache memory by up to 6x. Your task: dequantize the
compressed keys and compute dot-product attention scores against full-precision queries.
</p>
<p>
<strong>Background - how the keys were compressed</strong> (not part of the challenge; a reference sketch follows the list):
</p>
<ol>
<li><strong>Rotate</strong>: multiply the key by an orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
coordinate follow a Beta distribution, so a single fixed codebook works for all coordinates.</li>
<li><strong>Scalar quantize</strong>: replace each coordinate of \(y\) with the index of its nearest
codebook centroid \(\rightarrow K_\text{idx}\) (<code>uint8</code>).</li>
<li><strong>Residual correction</strong>: MSE quantization loses information. Compute the residual
\(r = K - \tilde{K}_\text{mse}\), then store:
<ul>
<li>\(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (<code>int8</code>)</li>
<li>\(\gamma = \|r\|_2\) - magnitude (<code>float32</code> scalar per key)</li>
</ul>
where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.
</li>
</ol>
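<p>
For reference, a minimal host-side sketch of this compression pipeline is below. It is illustrative
only: the function name, the row-major layouts, and the helper structure are assumptions, not part
of the challenge's starter code.
</p>
<pre><code>// Hypothetical reference for the compression steps above (not part of the
// challenge). Compresses one key K[D] given a DxD orthogonal Pi, a DxD
// Gaussian S_mat (both row-major), and a codebook of C centroids.
#include &lt;cmath&gt;
#include &lt;cstdint&gt;
#include &lt;vector&gt;

void compress_key(const float* K, const float* Pi, const float* S_mat,
                  const float* codebook, int D, int C,
                  uint8_t* K_idx, int8_t* qjl_signs, float* gamma) {
    std::vector&lt;float&gt; y(D), K_mse(D), r(D);

    // 1. Rotate: y = Pi * K.
    for (int i = 0; i &lt; D; ++i) {
        float acc = 0.0f;
        for (int j = 0; j &lt; D; ++j) acc += Pi[i * D + j] * K[j];
        y[i] = acc;
    }

    // 2. Scalar quantize: nearest-centroid index per coordinate.
    for (int i = 0; i &lt; D; ++i) {
        int best = 0;
        for (int c = 1; c &lt; C; ++c)
            if (fabsf(y[i] - codebook[c]) &lt; fabsf(y[i] - codebook[best])) best = c;
        K_idx[i] = (uint8_t)best;
    }

    // MSE reconstruction in the original basis: K_mse = codebook[K_idx] * Pi
    // (a row vector times Pi undoes the rotation, since Pi is orthogonal).
    for (int j = 0; j &lt; D; ++j) {
        float acc = 0.0f;
        for (int i = 0; i &lt; D; ++i) acc += codebook[K_idx[i]] * Pi[i * D + j];
        K_mse[j] = acc;
    }

    // 3. Residual correction: sigma = sign(S_mat * r), gamma = ||r||_2.
    float norm2 = 0.0f;
    for (int j = 0; j &lt; D; ++j) { r[j] = K[j] - K_mse[j]; norm2 += r[j] * r[j]; }
    *gamma = sqrtf(norm2);
    for (int i = 0; i &lt; D; ++i) {
        float acc = 0.0f;
        for (int j = 0; j &lt; D; ++j) acc += S_mat[i * D + j] * r[j];
        qjl_signs[i] = (acc &gt;= 0.0f) ? 1 : -1;
    }
}</code></pre>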
<p>
<strong>What you compute</strong> - dequantize and score (a kernel sketch follows the list):
</p>
<ol>
<li><strong>MSE dequantize</strong>: look up centroids, undo the rotation:
\[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]</li>
<li><strong>Residual dequantize</strong>: reconstruct the residual correction:
\[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat}\]
The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.</li>
<li><strong>Combine</strong>:
\(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)</li>
<li><strong>Dot product</strong>:
\(\text{scores}_{b,s} = Q_b \cdot \tilde{K}_s\)</li>
</ol>
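<p>
A minimal CUDA sketch of these four steps is below. It is not the official solution: the kernel
names, parameter order, and row-major layouts are assumptions, and the real entry point is the
provided <code>solve</code> function, whose signature must stay as given.
</p>
<pre><code>#include &lt;cuda_runtime.h&gt;
#include &lt;math_constants.h&gt;  // CUDART_PI_F
#include &lt;cstdint&gt;

// Steps 1-3: reconstruct one element K_tilde[s][j] per thread.
__global__ void dequantize_keys(const uint8_t* K_idx,    // [S, D]
                                const float* codebook,   // [C]
                                const float* Pi,         // [D, D], row-major
                                const int8_t* qjl_signs, // [S, D]
                                const float* gamma,      // [S]
                                const float* S_mat,      // [D, D], row-major
                                float* K_tilde,          // [S, D], output
                                int S, int D) {
    int s = blockIdx.y;
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (s &gt;= S || j &gt;= D) return;

    // MSE part: row vector codebook[K_idx[s]] times Pi, column j.
    float mse = 0.0f;
    for (int i = 0; i &lt; D; ++i)
        mse += codebook[K_idx[s * D + i]] * Pi[i * D + j];

    // Residual part: sqrt(pi/2)/D * gamma[s] * (sigma_s * S_mat), column j.
    float res = 0.0f;
    for (int i = 0; i &lt; D; ++i)
        res += (float)qjl_signs[s * D + i] * S_mat[i * D + j];
    res *= sqrtf(CUDART_PI_F / 2.0f) / (float)D * gamma[s];

    K_tilde[s * D + j] = mse + res;
}

// Step 4: scores[b][s] = dot(Q[b], K_tilde[s]), one thread per (b, s).
__global__ void score_keys(const float* Q, const float* K_tilde,
                           float* scores, int B, int S, int D) {
    int b = blockIdx.y;
    int s = blockIdx.x * blockDim.x + threadIdx.x;
    if (b &gt;= B || s &gt;= S) return;
    float acc = 0.0f;
    for (int d = 0; d &lt; D; ++d)
        acc += Q[b * D + d] * K_tilde[s * D + d];
    scores[b * S + s] = acc;
}</code></pre>
<p>
Dequantizing each key once into a temporary buffer, rather than inside the scoring loop, avoids
redoing the two \(D \times D\) products for every batch row; at the performance sizes a tiled,
shared-memory matrix multiply would be the natural next step.
</p>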
<p>
The residual correction makes the inner product <strong>unbiased</strong>:
\(\mathbb{E}[\langle Q, \tilde{K} \rangle] = \langle Q, K \rangle\).
</p>
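<p>
The reason, in one line: for each Gaussian row \(s_i\) of \(S_\text{mat}\),
\(\mathbb{E}[\text{sign}(\langle s_i, r \rangle)\, s_i] = \sqrt{2/\pi}\, r / \|r\|_2\), so summing
the \(D\) rows and applying the \(\sqrt{\pi/2}/D \cdot \gamma\) scale gives
\[
\mathbb{E}[\tilde{K}_\text{res}] = \frac{\sqrt{\pi/2}}{D} \cdot \|r\|_2 \cdot D \cdot \sqrt{\frac{2}{\pi}} \cdot \frac{r}{\|r\|_2} = r,
\]
hence \(\mathbb{E}[\tilde{K}] = \tilde{K}_\text{mse} + r = K\), and unbiasedness of the dot product follows by linearity.
</p>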
<h2>Implementation Requirements</h2>
<ul>
<li>The <code>solve</code> function signature must remain unchanged.</li>
<li>Use only native features (no external libraries).</li>
<li>Store the result in <code>scores</code> as <code>float32</code>.</li>
</ul>
<h2>Example</h2>
<p>
Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled),
\(\sigma = \mathbf{1}\) (all +1):
</p>
<p>
\(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
\(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
</p>
<p>
Step 1 - MSE lookup and rotate back (\(\Pi = I\)):
\[
\tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
\]
Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
</p>
<p>
Output:
\[
\text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} -0.75 & -0.25 & 0.75 \\ 0.75 & 0.25 & -0.75 \end{bmatrix}
\]
</p>
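<p>
A standalone host-side check of this example (hypothetical, not the grading harness), assuming the
shapes above; with \(\Pi = I\) and \(\gamma = 0\), dequantization reduces to a codebook lookup:
</p>
<pre><code>#include &lt;cstdint&gt;
#include &lt;cstdio&gt;

int main() {
    const int B = 2, S = 3, D = 2;
    float Q[B][D] = {{1, 0}, {0, 1}};
    uint8_t K_idx[S][D] = {{0, 3}, {1, 2}, {3, 0}};
    float codebook[4] = {-0.75f, -0.25f, 0.25f, 0.75f};

    for (int b = 0; b &lt; B; ++b) {
        for (int s = 0; s &lt; S; ++s) {
            float acc = 0.0f;
            for (int d = 0; d &lt; D; ++d)
                acc += Q[b][d] * codebook[K_idx[s][d]]; // K_tilde == K_mse here
            printf("%6.2f", acc);
        }
        printf("\n");  // row b of scores
    }
    // Output: -0.75 -0.25  0.75
    //          0.75  0.25 -0.75
}</code></pre>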
<h2>Constraints</h2>
<ul>
<li>1 ≤ <code>B</code> ≤ 32</li>
<li>1 ≤ <code>S</code> ≤ 65,536</li>
<li>1 ≤ <code>D</code> ≤ 256</li>
<li>2 ≤ <code>C</code> ≤ 256</li>
<li>\(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))</li>
<li><code>S_mat</code> has i.i.d. \(\mathcal{N}(0,1)\) entries</li>
<li><code>gamma</code> has shape \([S]\) (one \(\ell_2\) norm per key vector, <code>float32</code>)</li>
<li><code>qjl_signs</code> (\(\sigma\)) values are in \(\{-1, +1\}\) (<code>int8</code>)</li>
<li><code>K_idx</code> values are in \([0, C)\) (<code>uint8</code>)</li>
<li>All floating-point inputs are <code>float32</code></li>
<li>Performance is measured with <code>B</code> = 32, <code>S</code> = 32,768, <code>D</code> = 128, <code>C</code> = 16</li>
</ul>