@@ -68,9 +68,130 @@ pub fn clear_attention_cache() {
6868}
6969
7070// ============================================================================
71- // VNNI u8 MatVec fast path — 64 MACs per instruction
71+ // Compiled Linear Cache — O(k) replacing O(n_rows) for any weight matrix
7272// ============================================================================
7373//
74+ // For any linear layer y = W @ x, where W is [n_rows, n_cols]:
75+ // 1. Each row of W is assigned to one of 256 palette centroids (u8 index)
76+ // 2. At inference: compute k=256 centroid dot products with input x
77+ // 3. For each output row i: y[i] = centroid_outputs[assignment[i]]
78+ //
79+ // Cost: 256 × n_cols MACs + n_rows lookups (vs n_rows × n_cols MACs)
80+ // For gate_proj [3072, 1024]: 256K MACs vs 3.1M MACs = 12× fewer.
81+ //
82+ // Keyed by (n_rows, n_cols) — the weight matrix shape.
83+
/// A linear layer `y = W @ x` compressed to a palette of centroid rows.
///
/// Each row of the original `[n_rows, n_cols]` weight matrix is represented by
/// one of `k` centroid rows, so a product needs `k` dot products plus one table
/// lookup per output row instead of `n_rows` dot products.
///
/// Expected invariants (not enforced by this type):
/// - `centroids.len() == k * n_cols`
/// - `assignments.len() == n_rows`
/// - every entry of `assignments` should be `< k`; since assignments are `u8`,
///   at most 256 centroids are addressable.
#[cfg(feature = "std")]
#[derive(Clone, Debug)]
pub struct CompiledLinear {
    /// Centroid weight vectors, `[k × n_cols]` `f32`, row-major:
    /// centroid `c` occupies `centroids[c * n_cols..(c + 1) * n_cols]`.
    pub centroids: Vec<f32>,
    /// Number of centroids (palette size, typically 256).
    pub k: usize,
    /// Input dimension (`n_cols` of the original weight matrix).
    pub n_cols: usize,
    /// Output dimension (`n_rows` of the original weight matrix).
    pub n_rows: usize,
    /// Row assignment: `assignments[i]` is the centroid index used for output row `i`.
    pub assignments: Vec<u8>,
}
100+
/// Process-wide registry of compiled linear layers, stored in registration order.
///
/// Entries are identified only by their original weight-matrix shape
/// `(n_rows, n_cols)`. Several entries with the same shape may coexist in the
/// `Vec`; a shape lookup scans from the front and therefore yields the earliest
/// registered entry with that shape.
#[cfg(feature = "std")]
static LINEAR_CACHE: LazyLock<RwLock<Vec<CompiledLinear>>> = LazyLock::new(RwLock::default);
107+
/// Append a compiled linear layer to the global cache.
///
/// The entry is pushed to the back of the registry as-is; no validation or
/// de-duplication is performed here.
///
/// # Panics
///
/// Panics if the cache lock has been poisoned.
#[cfg(feature = "std")]
pub fn register_compiled_linear(compiled: CompiledLinear) {
    LINEAR_CACHE.write().unwrap().push(compiled);
}
114+
/// Look up the compiled linear layer registered for a weight-matrix shape.
///
/// Despite the name, nothing is removed from the cache: the first entry whose
/// `(n_rows, n_cols)` equals the requested shape is cloned and returned, so the
/// same entry is served on every call and can be reused across batches.
///
/// NOTE(review): because the match is by shape only and always takes the first
/// hit, a later-registered layer with the same shape is never returned.
/// NOTE(review): the returned value is a deep copy, including the whole centroid
/// buffer, made on every lookup.
///
/// Returns `None` if no entry with this shape has been registered.
///
/// # Panics
///
/// Panics if the cache lock has been poisoned.
#[cfg(feature = "std")]
fn pop_compiled_linear(n_rows: usize, n_cols: usize) -> Option<CompiledLinear> {
    let cache = LINEAR_CACHE
        .read()
        .expect("LINEAR_CACHE lock poisoned");
    cache
        .iter()
        .find(|layer| layer.n_rows == n_rows && layer.n_cols == n_cols)
        .cloned()
}
124+
/// Try to compute `out = W @ rhs` using a registered centroid palette.
///
/// Instead of `m × k_dim` multiply-accumulates per input column:
/// 1. compute one dot product per centroid: `centroid_out[c] = dot(centroid[c], rhs[:, j])`
/// 2. for each output row `i`: `out[i, j] = centroid_out[assignment[i]]`
///
/// Dot products are accumulated in `f64` and converted back to `E` on store.
///
/// NOTE(review): the compiled layer is selected purely by the shape `(m, k_dim)`;
/// the contents of the left-hand matrix are never read, so any product of that
/// shape takes this path once a layer with the same shape is registered.
///
/// # Arguments
///
/// * `_lhs` - weight matrix `[m, k_dim]`; unused, its shape alone selects the layer.
/// * `rhs` - input `[k_dim, n]`, read column by column.
/// * `out` - output `[m, n]`, fully overwritten when `true` is returned.
/// * `m`, `k_dim`, `n` - the matmul dimensions.
///
/// # Returns
///
/// `true` if the compiled path produced the output. `false` if no layer is
/// registered for this shape or the registered layer is malformed (empty
/// palette, too few assignments, or a centroid buffer shorter than
/// `k * n_cols`); in that case `out` is untouched and the caller should fall
/// back to the regular matmul.
#[cfg(feature = "std")]
fn try_compiled_linear<E: NdArrayElement>(
    _lhs: &ndarray::ArrayView2<'_, E>,
    rhs: &ndarray::ArrayView2<'_, E>,
    out: &mut ndarray::ArrayViewMut2<'_, E>,
    m: usize,
    k_dim: usize,
    n: usize,
) -> bool {
    let Some(compiled) = pop_compiled_linear(m, k_dim) else {
        return false;
    };

    let palette = compiled.k;
    let n_cols = compiled.n_cols;

    // Reject malformed entries up front so a bad registration falls back to the
    // regular path instead of panicking on an out-of-range slice mid-matmul.
    let Some(required) = palette.checked_mul(n_cols) else {
        return false;
    };
    if palette == 0
        || n_cols != k_dim
        || compiled.assignments.len() < m
        || compiled.centroids.len() < required
    {
        return false;
    }

    // Scratch buffers are allocated once and reused for every input column.
    let mut column = vec![0.0f64; n_cols];
    let mut centroid_out = vec![0.0f64; palette];

    for j in 0..n {
        // Convert column j of rhs to f64 once rather than once per centroid.
        for (d, x) in column.iter_mut().enumerate() {
            *x = rhs[[d, j]].elem();
        }

        // Step 1: one dot product per centroid.
        for (c, slot) in centroid_out.iter_mut().enumerate() {
            let centroid_row = &compiled.centroids[c * n_cols..(c + 1) * n_cols];
            *slot = centroid_row
                .iter()
                .zip(&column)
                .fold(0.0f64, |acc, (&w, &x)| acc + f64::from(w) * x);
        }

        // Step 2: broadcast through the palette assignment. Indices beyond the
        // palette are clamped to the last centroid.
        for (i, &assignment) in compiled.assignments[..m].iter().enumerate() {
            let c_idx = usize::from(assignment).min(palette - 1);
            out[[i, j]] = centroid_out[c_idx].elem();
        }
    }

    true
}
182+
/// Number of compiled linear layers currently held in the global cache.
///
/// # Panics
///
/// Panics if the cache lock has been poisoned.
#[cfg(feature = "std")]
pub fn compiled_linear_count() -> usize {
    let cache = LINEAR_CACHE.read().unwrap();
    cache.len()
}
188+
/// Remove every compiled linear layer from the global cache.
///
/// # Panics
///
/// Panics if the cache lock has been poisoned.
#[cfg(feature = "std")]
pub fn clear_compiled_linear_cache() {
    let mut cache = LINEAR_CACHE.write().unwrap();
    cache.clear();
}
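
// ============================================================================
// VNNI u8 MatVec fast path — 64 MACs per instruction
// ============================================================================
//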
74195// For quantized u8×i8 matmul (codebook distance table build):
75196// Input A: [m, k] u8 (codebook rows, quantized)
76197// Input B: [k, n] i8 (codebook cols, quantized)
@@ -355,6 +476,13 @@ pub(crate) fn matmul<E: NdArrayElement>(
355476 . get( )
356477 . slice_mut( s!( out_batch, .., ..) ) ;
357478
479+ // Try compiled linear (centroid matmul, O(256) per column).
480+ // Falls through to BLAS if no compiled layer matches.
481+ #[ cfg( feature = "std" ) ]
482+ if try_compiled_linear( & lhs_slice, & rhs_slice, & mut out_slice, m, k, n) {
483+ return ;
484+ }
485+
358486 // Try compiled attention table (O(1) per element).
359487 // Falls through to BLAS if no table is registered for d_head=k.
360488 #[ cfg( feature = "std" ) ]
0 commit comments