66 compressed keys and compute dot-product attention scores against full-precision queries.
77</ p >
88
9- < svg width ="660 " height ="300 " viewBox ="0 0 660 300 " xmlns ="http://www.w3.org/2000/svg "
10- style ="display:block; margin:20px auto; font-family:monospace; ">
11- < rect width ="660 " height ="300 " fill ="#222 " rx ="8 "/>
12-
13- <!-- Title -->
14- < text x ="330 " y ="20 " text-anchor ="middle " fill ="#888 " font-size ="10 "> TurboQuant Dequantization Pipeline (per key vector)</ text >
15-
16- <!-- ============================================================ -->
17- <!-- Stage 1: MSE dequantization (top row) -->
18- <!-- ============================================================ -->
19- < text x ="16 " y ="50 " fill ="#4a9edd " font-size ="10 " font-weight ="bold "> Stage 1: MSE</ text >
20-
21- <!-- K_idx -->
22- < rect x ="16 " y ="60 " width ="80 " height ="40 " rx ="4 " fill ="#1a3a5c " stroke ="#4a9edd " stroke-width ="1.5 "/>
23- < text x ="56 " y ="78 " text-anchor ="middle " fill ="#4a9edd " font-size ="10 "> K_idx</ text >
24- < text x ="56 " y ="92 " text-anchor ="middle " fill ="#7799bb " font-size ="8 "> [S, D] uint8</ text >
25-
26- <!-- arrow -->
27- < text x ="110 " y ="84 " fill ="#888 " font-size ="11 " text-anchor ="middle "> →</ text >
28-
29- <!-- codebook lookup -->
30- < rect x ="124 " y ="60 " width ="120 " height ="40 " rx ="4 " fill ="#1a1a1a " stroke ="#666 " stroke-width ="1 "/>
31- < text x ="184 " y ="78 " text-anchor ="middle " fill ="#ccc " font-size ="9 "> codebook[ K_idx ]</ text >
32- < text x ="184 " y ="92 " text-anchor ="middle " fill ="#888 " font-size ="8 "> centroid lookup</ text >
33-
34- <!-- arrow -->
35- < text x ="258 " y ="84 " fill ="#888 " font-size ="11 " text-anchor ="middle "> →</ text >
36-
37- <!-- multiply by Pi -->
38- < rect x ="272 " y ="60 " width ="80 " height ="40 " rx ="4 " fill ="#1a1a1a " stroke ="#666 " stroke-width ="1 "/>
39- < text x ="312 " y ="78 " text-anchor ="middle " fill ="#ccc " font-size ="9 "> × Π</ text >
40- < text x ="312 " y ="92 " text-anchor ="middle " fill ="#888 " font-size ="8 "> rotate back</ text >
41-
42- <!-- = -->
43- < text x ="366 " y ="84 " fill ="#888 " font-size ="11 " text-anchor ="middle "> =</ text >
44-
45- <!-- K_mse result -->
46- < rect x ="380 " y ="60 " width ="100 " height ="40 " rx ="4 " fill ="#1a3a5c " stroke ="#4a9edd " stroke-width ="1.5 "/>
47- < text x ="430 " y ="82 " text-anchor ="middle " fill ="#4a9edd " font-size ="10 "> K̃_mse</ text >
48- < text x ="430 " y ="94 " text-anchor ="middle " fill ="#7799bb " font-size ="8 "> [S, D] float32</ text >
49-
50- <!-- ============================================================ -->
51- <!-- Stage 2: QJL dequantization (middle row) -->
52- <!-- ============================================================ -->
53- < text x ="16 " y ="130 " fill ="#52b788 " font-size ="10 " font-weight ="bold "> Stage 2: QJL residual</ text >
54-
55- <!-- sigma (signs) -->
56- < rect x ="16 " y ="140 " width ="60 " height ="40 " rx ="4 " fill ="#1a3a2a " stroke ="#52b788 " stroke-width ="1.5 "/>
57- < text x ="46 " y ="158 " text-anchor ="middle " fill ="#52b788 " font-size ="10 "> σ</ text >
58- < text x ="46 " y ="172 " text-anchor ="middle " fill ="#77bbaa " font-size ="8 "> [S,D] ±1</ text >
59-
60- <!-- multiply by S_proj -->
61- < text x ="90 " y ="164 " fill ="#888 " font-size ="11 " text-anchor ="middle "> →</ text >
62-
63- < rect x ="104 " y ="140 " width ="60 " height ="40 " rx ="4 " fill ="#1a1a1a " stroke ="#666 " stroke-width ="1 "/>
64- < text x ="134 " y ="158 " text-anchor ="middle " fill ="#ccc " font-size ="9 "> σ · M</ text >
65- < text x ="134 " y ="172 " text-anchor ="middle " fill ="#888 " font-size ="8 "> project</ text >
66-
67- <!-- multiply by scale * gamma -->
68- < text x ="178 " y ="164 " fill ="#888 " font-size ="11 " text-anchor ="middle "> →</ text >
69-
70- < rect x ="192 " y ="140 " width ="120 " height ="40 " rx ="4 " fill ="#1a1a1a " stroke ="#666 " stroke-width ="1 "/>
71- < text x ="252 " y ="158 " text-anchor ="middle " fill ="#ccc " font-size ="9 "> × √(π/2)/D × γ</ text >
72- < text x ="252 " y ="172 " text-anchor ="middle " fill ="#888 " font-size ="8 "> scale by norm</ text >
73-
74- <!-- = -->
75- < text x ="326 " y ="164 " fill ="#888 " font-size ="11 " text-anchor ="middle "> =</ text >
76-
77- <!-- K_qjl result -->
78- < rect x ="340 " y ="140 " width ="100 " height ="40 " rx ="4 " fill ="#1a3a2a " stroke ="#52b788 " stroke-width ="1.5 "/>
79- < text x ="390 " y ="162 " text-anchor ="middle " fill ="#52b788 " font-size ="10 "> K̃_res</ text >
80- < text x ="390 " y ="174 " text-anchor ="middle " fill ="#77bbaa " font-size ="8 "> [S, D] float32</ text >
81-
82- <!-- ============================================================ -->
83- <!-- Combine + dot product (bottom) -->
84- <!-- ============================================================ -->
85- < text x ="16 " y ="210 " fill ="#e0a040 " font-size ="10 " font-weight ="bold "> Combine + Score</ text >
86-
87- <!-- K_mse + K_res -->
88- < rect x ="16 " y ="220 " width ="100 " height ="36 " rx ="4 " fill ="#1a3a5c " stroke ="#4a9edd " stroke-width ="1 "/>
89- < text x ="66 " y ="242 " text-anchor ="middle " fill ="#4a9edd " font-size ="10 "> K̃_mse</ text >
90-
91- < text x ="128 " y ="242 " fill ="#ccc " font-size ="14 " text-anchor ="middle "> +</ text >
92-
93- < rect x ="142 " y ="220 " width ="100 " height ="36 " rx ="4 " fill ="#1a3a2a " stroke ="#52b788 " stroke-width ="1 "/>
94- < text x ="192 " y ="242 " text-anchor ="middle " fill ="#52b788 " font-size ="10 "> K̃_res</ text >
95-
96- < text x ="256 " y ="242 " fill ="#ccc " font-size ="14 " text-anchor ="middle "> =</ text >
97-
98- < rect x ="270 " y ="220 " width ="80 " height ="36 " rx ="4 " fill ="#3a2a1a " stroke ="#e0a040 " stroke-width ="1.5 "/>
99- < text x ="310 " y ="242 " text-anchor ="middle " fill ="#e0a040 " font-size ="10 "> K̃</ text >
100-
101- <!-- dot with Q -->
102- < text x ="370 " y ="242 " fill ="#ccc " font-size ="10 " text-anchor ="middle "> then:</ text >
103-
104- < rect x ="394 " y ="220 " width ="70 " height ="36 " rx ="4 " fill ="#2a1a3a " stroke ="#c060e0 " stroke-width ="1.5 "/>
105- < text x ="429 " y ="242 " text-anchor ="middle " fill ="#c060e0 " font-size ="10 "> Q</ text >
106-
107- < text x ="478 " y ="242 " fill ="#ccc " font-size ="14 " text-anchor ="middle "> ·</ text >
108-
109- < rect x ="492 " y ="220 " width ="70 " height ="36 " rx ="4 " fill ="#3a2a1a " stroke ="#e0a040 " stroke-width ="1 "/>
110- < text x ="527 " y ="242 " text-anchor ="middle " fill ="#e0a040 " font-size ="10 "> K̃ᵀ</ text >
111-
112- < text x ="576 " y ="242 " fill ="#ccc " font-size ="14 " text-anchor ="middle "> =</ text >
113-
114- < rect x ="590 " y ="220 " width ="56 " height ="36 " rx ="4 " fill ="#3a1a1a " stroke ="#e05050 " stroke-width ="1.5 "/>
115- < text x ="618 " y ="242 " text-anchor ="middle " fill ="#e05050 " font-size ="10 "> scores</ text >
116-
117- <!-- Legend -->
118- < text x ="16 " y ="280 " fill ="#888 " font-size ="9 "> Π = orthogonal rotation [D×D]</ text >
119- < text x ="200 " y ="280 " fill ="#888 " font-size ="9 "> M = Gaussian projection [D×D]</ text >
120- < text x ="420 " y ="280 " fill ="#888 " font-size ="9 "> σ = sign bits ±1, γ = ‖residual‖₂</ text >
121- </ svg >
122-
1239< p >
124- < strong > Background — how the keys were compressed</ strong > (already done for you, not part of the challenge):
10+ < strong > Background - how the keys were compressed</ strong > (already done for you, not part of the challenge):
12511</ p >
12612< ol >
12713 < li > < strong > Rotate</ strong > : multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
13117 < li > < strong > Residual correction</ strong > : MSE quantization loses information. Compute the residual
13218 \(r = K - \tilde{K}_\text{mse}\), then store:
13319 < ul >
134- < li > \(\sigma = \text{sign}(M \cdot r) \in \{-1,+1\}^D\) — direction (< code > int8</ code > )</ li >
135- < li > \(\gamma = \|r\|_2\) — magnitude (< code > float32</ code > scalar per key)</ li >
20+ < li > \(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (< code > int8</ code > )</ li >
21+ < li > \(\gamma = \|r\|_2\) - magnitude (< code > float32</ code > scalar per key)</ li >
13622 </ ul >
137- where \(M \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix ( < code > S_mat </ code > in code) .
23+ where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.
13824 </ li >
13925</ ol >
14026
14127< p >
142- < strong > What you compute</ strong > — dequantize and score:
28+ < strong > What you compute</ strong > - dequantize and score:
14329</ p >
14430< ol >
14531 < li > < strong > MSE dequantize</ strong > : look up centroids, undo the rotation:
14632 \[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]</ li >
14733 < li > < strong > Residual dequantize</ strong > : reconstruct the residual correction:
148- \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot M \]
34+ \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat} \]
14935 The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.</ li >
15036 < li > < strong > Combine</ strong > :
15137 \(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)</ li >
@@ -166,19 +52,20 @@ <h2>Implementation Requirements</h2>
16652
16753< h2 > Example</ h2 >
16854< p >
169- Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(M = I\), \(\gamma = \mathbf{0}\) (residual correction disabled):
55+ Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled),
56+ \(\sigma = \mathbf{1}\) (all +1):
17057</ p >
17158< p >
172- \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\), \quad
173- \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\), \quad
59+ \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
60+ \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
17461 codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
17562</ p >
17663< p >
177- Step 1 — MSE lookup and rotate back (\(\Pi = I\)):
64+ Step 1 - MSE lookup and rotate back (\(\Pi = I\)):
17865 \[
17966 \tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
18067 \]
181- Step 2 — Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
68+ Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
18269</ p >
18370< p >
18471 Output:
@@ -194,7 +81,8 @@ <h2>Constraints</h2>
19481 < li > 1 ≤ < code > D</ code > ≤ 256</ li >
19582 < li > 2 ≤ < code > C</ code > ≤ 256</ li >
19683 < li > \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))</ li >
197- < li > \(M\) (< code > S_mat</ code > ) has i.i.d. \(\mathcal{N}(0,1)\) entries</ li >
84+ < li > < code > S_mat</ code > has i.i.d. \(\mathcal{N}(0,1)\) entries</ li >
85+ < li > < code > gamma</ code > has shape \([S]\) (one \(\ell_2\) norm per key vector, < code > float32</ code > )</ li >
19886 < li > < code > qjl_signs</ code > (\(\sigma\)) values are in \(\{-1, +1\}\) (< code > int8</ code > )</ li >
19987 < li > < code > K_idx</ code > values are in \([0, C)\) (< code > uint8</ code > )</ li >
20088 < li > All floating-point inputs are < code > float32</ code > </ li >
0 commit comments