# RSR Multiplier Optimization Guide

How the RSR multipliers in this repository are optimized and why they are fast.

**Scope:** RSR implementations under `multiplier/` and `kernels/`.

---

## The Core Idea

Every RSR multiplier does this:

1. **Preprocess** the weight matrix once.
2. **Group** columns with identical `k`-row block patterns.
3. At inference, **aggregate** the input values for each group once.
4. **Scatter** that sum to the affected output rows.

Naive low-bit matvec touches every matrix entry. RSR touches every *unique column pattern per block* instead. Everything else in this document is about reducing the overhead around that algorithmic win.
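
The four steps above can be sketched in plain NumPy. This is a toy reference for one row block, not any of the repository's kernels; function and variable names here are illustrative:

```python
import numpy as np

def rsr_matvec_block(W_block, v):
    """Toy RSR matvec for one row block of binary weights in {0, 1}.

    W_block: (k, n) 0/1 matrix, v: (n,) input vector.
    Instead of touching all k*n entries, group columns that share the
    same k-bit pattern, sum each group's inputs once, and scatter that
    sum to the rows the pattern sets.
    """
    k, n = W_block.shape
    # Steps 1-2: encode each column as a k-bit integer; equal codes form a group.
    codes = (W_block.T @ (1 << np.arange(k))).astype(np.int64)  # (n,)
    out = np.zeros(k)
    for code in np.unique(codes):
        if code == 0:
            continue  # the all-zero pattern never affects the output
        # Step 3: aggregate this group's inputs once.
        agg = v[codes == code].sum()
        # Step 4: scatter the sum to every row whose bit is set.
        for row in range(k):
            if (code >> row) & 1:
                out[row] += agg
    return out

rng = np.random.default_rng(0)
W = rng.integers(0, 2, size=(8, 64))
v = rng.standard_normal(64)
assert np.allclose(rsr_matvec_block(W, v), W @ v)
```

With many repeated patterns, the loop body runs once per unique pattern rather than once per column, which is the entire algorithmic win.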

## Preprocessing

For a binary block, each column becomes a `k`-bit integer. For a ternary block, each column becomes two `k`-bit integers (positive mask, negative mask), combined into a `2k`-bit ternary code. Columns with the same code share one aggregate at inference time.

The fastest preprocessing uses **counting sort** over the discrete pattern space (`2^k` buckets for binary, `4^k` for ternary), giving `O(n + buckets)` per block instead of `O(n log n)`.
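
A sketch of that counting sort over the binary pattern space, under assumed names (`perm`, `group_codes`, `group_starts` are illustrative, not the repository's exact layout):

```python
import numpy as np

def group_by_pattern_counting(codes, k):
    """Counting-sort grouping: bucket column codes over the discrete
    pattern space (2**k buckets for a binary block).

    Returns (perm, group_codes, group_starts): a permutation that puts
    columns with equal patterns next to each other, plus per-group
    metadata. O(n + 2**k) per block, with no comparison sort.
    """
    buckets = 1 << k
    counts = np.bincount(codes, minlength=buckets)
    starts = np.zeros(buckets, dtype=np.int64)
    np.cumsum(counts[:-1], out=starts[1:])       # exclusive prefix sum
    perm = np.empty(len(codes), dtype=np.int64)
    cursor = starts.copy()
    for col, code in enumerate(codes):           # stable placement pass
        perm[cursor[code]] = col
        cursor[code] += 1
    group_codes = np.flatnonzero(counts)         # patterns that occur
    return perm, group_codes, starts[group_codes]

codes = np.array([5, 0, 5, 3, 0, 5])
perm, gcodes, gstarts = group_by_pattern_counting(codes, k=3)
assert list(gcodes) == [0, 3, 5]
assert list(codes[perm]) == [0, 0, 3, 5, 5, 5]
```

Because the pattern space is small and known in advance, this beats general-purpose sorting, which is exactly why the `k` limits below exist.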

## Why `k` Limits Exist

- Binary counting sort: `2^k` buckets — impractical for very large `k`
- Ternary counting sort: `4^k` buckets — grows much faster
- Bitmask-scatter variants store row membership in `uint16` → requires `k ≤ 16`
- 16-bit permutation indices require column count ≤ 65535

These are not arbitrary guardrails — they are what enable compact metadata and cheap inner loops.

---

## Binary CPU: `multiplier/bit_1`

### `RSRPythonMultiplier`

*Files: `multiplier/bit_1/rsr_py.py`*

Pure Python/PyTorch reference. Encodes columns into integers, comparison-sorts by pattern, finds unique patterns, aggregates with `scatter_add_`, and scatters via `unique_bits.T @ aggregated`. Proves the algorithm; not the fastest due to PyTorch dispatch overhead and general-purpose sorting.
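
The same flow in a NumPy stand-in (the real class uses `torch.scatter_add_`; `np.add.at` plays that role here, and the final `unique_bits.T @ aggregated` matmul mirrors the class's scatter):

```python
import numpy as np

def rsr_reference(W_block, v):
    """NumPy stand-in for the RSRPythonMultiplier flow on one block."""
    k, n = W_block.shape
    # Encode columns into k-bit integers.
    codes = (W_block.T @ (1 << np.arange(k))).astype(np.int64)
    # Find unique patterns and each column's group index.
    unique_codes, inverse = np.unique(codes, return_inverse=True)
    # Aggregate: one sum per unique pattern (scatter_add_ analogue).
    aggregated = np.zeros(len(unique_codes))
    np.add.at(aggregated, inverse, v)
    # Scatter: expand codes back to bit matrices and multiply.
    unique_bits = (unique_codes[:, None] >> np.arange(k)) & 1  # (g, k)
    return unique_bits.T @ aggregated

rng = np.random.default_rng(0)
W = rng.integers(0, 2, size=(8, 128))
v = rng.standard_normal(128)
assert np.allclose(rsr_reference(W, v), W @ v)
```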

### `RSRCppMultiplier`

*Files: `multiplier/bit_1/cpu/rsr_cpp.py`, `kernels/bit_1/cpu/rsr_prep.c`, `kernels/bit_1/cpu/rsr.c`*

Moves the hot path to C. Preprocessing uses counting sort (parallelized with OpenMP). Inference is a single fused kernel for gather, aggregate, and scatter. Long groups use AVX2 `_mm256_i32gather_ps` with 4× unrolling.

### `RSRCppV2_4Multiplier`

*Files: `multiplier/bit_1/cpu/rsr_cpp_v2_4.py`, `kernels/bit_1/cpu/rsr_v2_4.c`*

Adds `schedule(static)` OpenMP and pre-allocates one 64-byte-aligned `v_perm` buffer per thread. The kernel first gathers `v[perm]` into the contiguous buffer, then aggregates over contiguous slices. Trades extra buffer traffic for a cleaner aggregation phase — works well when contiguous summation outweighs the write cost.
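
A NumPy model of that two-phase flow, with plain slices standing in for the aligned per-thread C buffer (names here are illustrative):

```python
import numpy as np

def aggregate_two_phase(v, perm, group_starts):
    """v2.4-style sketch: gather v[perm] into one contiguous buffer
    first, then sum contiguous slices per group. This shows only the
    data movement the kernel trades for a cleaner aggregation phase.
    """
    v_perm = v[perm]                                  # phase 1: permuted gather
    ends = np.append(group_starts[1:], len(perm))     # each group's end offset
    # phase 2: contiguous slice sums (vectorization-friendly in C)
    return np.array([v_perm[s:e].sum() for s, e in zip(group_starts, ends)])

v = np.array([1.0, 2.0, 3.0, 4.0])
perm = np.array([2, 0, 1, 3])          # groups: columns [2, 0] and [1, 3]
starts = np.array([0, 2])
assert np.allclose(aggregate_two_phase(v, perm, starts), [4.0, 6.0])
```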

### `RSRCppV4_2Multiplier`

*Files: `multiplier/bit_1/cpu/rsr_cpp_v4_2.py`, `kernels/bit_1/cpu/rsr_v4_2.c`*

Removes the `v_perm` buffer entirely. Gathers and aggregates in one pass directly from `v`. Uses scalar-unrolled loads (8-way switch) with `_mm_prefetch` at 64-element distance instead of AVX2 gather. Less memory traffic, less temporary storage, better cache behavior on the actual bottleneck (random reads from `v`).

**This is the key binary CPU kernel.**

### `RSRCppNonSquareMultiplier`

*Files: `multiplier/bit_1/cpu/rsr_cpp_nonsquare.py`, `kernels/bit_1/cpu/rsr_prep_nonsquare.c`*

Supports `n_rows × n_cols` matrices by padding rows to a multiple of `k`, running non-square preprocessing, and reusing the v4.2 inference kernel with `n_cols` as the permutation stride. No second inference kernel — just metadata adaptation.
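
The padding step is simple enough to show directly. A sketch (the function name is illustrative) of why it is safe: the added zero rows form all-zero patterns, which never touch the output:

```python
import numpy as np

def pad_rows_to_k(W, k):
    """Pad the row count up to the next multiple of k with zero rows,
    so every k-row block is full. The zero rows yield all-zero column
    patterns that contribute nothing, and the inference kernel simply
    uses n_cols as its permutation stride.
    """
    n_rows, n_cols = W.shape
    padded_rows = -(-n_rows // k) * k     # ceil(n_rows / k) * k
    W_pad = np.zeros((padded_rows, n_cols), dtype=W.dtype)
    W_pad[:n_rows] = W
    return W_pad

W = np.ones((10, 7), dtype=np.int8)
W_pad = pad_rows_to_k(W, k=4)
assert W_pad.shape == (12, 7)
# The first n_rows outputs are unchanged; the padded rows are dead weight.
assert np.array_equal((W_pad @ np.ones(7))[:10], W @ np.ones(7))
```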

### `RSRAdaptiveMultiplier`

*Files: `multiplier/bit_1/cpu/rsr_adaptive.py`*

For square matrices not divisible by `k`, pads to the next multiple and delegates to `RSRCppV4_2Multiplier`.

---

## Binary CUDA: `multiplier/bit_1/cuda`

### Shared Preprocessing

*Files: `_prep_cuda.py`, `_prep_cuda_nonsquare.py`*

Preprocessing runs on CPU (counting sort is one-time work). Metadata is rearranged into GPU-friendly tensors. Several versions sort permutation indices within each group — this does not change the sum but makes reads from `v` more spatially local, improving L2 behavior.

### `RSRCudaV4_10Multiplier`

*Files: `rsr_cuda_v4_10.py`, `kernels/bit_1/cuda/rsr_v4_10.cu`*

- 16-bit permutations
- Precomputed `group_starts`
- Sorted perms within groups
- One CUDA block per row block; warps process groups, lane 0 scatters to shared memory
- 8× unrolled gather (256 elements per iteration)
- Adaptive thread count: 128/256/512 based on `k`

### `RSRCudaV5_7Multiplier`

*Files: `rsr_cuda_v5_7.py`, `kernels/bit_1/cuda/rsr_v5_6.cu`*

Introduces **packed metadata**: each group is one `int4(start, end, row_mask, 0)`. Scatter becomes bit operations over the row mask instead of following a variable-length row array. Each warp writes into its own shared-memory partial buffer. Processes two groups per warp step. Fixed 256 threads/block.

One metadata load per group instead of multiple array reads → less global memory traffic.
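
A minimal Python model of the packed-metadata scatter. The `(start, end, row_mask)` tuples stand in for the `int4` records, and the bit tricks mirror what the CUDA kernel does per warp:

```python
import numpy as np

def scatter_with_row_mask(v_perm, groups, k):
    """Packed-metadata scatter sketch: each group is one compact
    (start, end, row_mask) record. The scatter walks the set bits of
    row_mask instead of following a variable-length row array.
    """
    out = np.zeros(k)
    for start, end, row_mask in groups:
        agg = v_perm[start:end].sum()      # one contiguous group sum
        mask = row_mask
        while mask:                        # iterate set bits
            row = (mask & -mask).bit_length() - 1   # index of lowest set bit
            out[row] += agg
            mask &= mask - 1               # clear lowest set bit
    return out

v_perm = np.array([1.0, 2.0, 3.0])
groups = [(0, 2, 0b0101), (2, 3, 0b0010)]  # rows {0, 2} and {1}
assert np.allclose(scatter_with_row_mask(v_perm, groups, k=4),
                   [3.0, 3.0, 3.0, 0.0])
```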

### `RSRCudaV5_8Multiplier`

*Files: `rsr_cuda_v5_8.py`, `kernels/bit_1/cuda/rsr_v5_8.cu`*

Same packed-metadata design as v5.7. Uses 1024 threads when `k > 4` (256 otherwise) and builds with `--use_fast_math`. More warps per block means more group-level parallelism on larger `k`.

### `RSRCudaV5_9Multiplier`

*Files: `rsr_cuda_v5_9.py`, `kernels/bit_1/cuda/rsr_v5_9.cu`*

Keeps packed metadata and sorted perms. Stores permutations as `uint16` — at large `n`, the permutation array is the biggest metadata stream, so halving it directly lowers bandwidth pressure. Uses 256 threads for `k ≤ 4`, 512 otherwise.

**This is the main large-`n` binary CUDA kernel.**

### `RSRCudaV5_9NonSquareMultiplier`

*Files: `rsr_cuda_v5_9_nonsquare.py`*

Pads rows to a multiple of `k`, runs non-square CPU preprocessing, sorts within groups, and reuses the v5.9 kernel with `n_cols` as stride.

### `RSRCudaV5_10Multiplier`

*Files: `rsr_cuda_v5_10.py`*

Not a new kernel — an empirical dispatcher:

| Condition | Kernel |
|:---|:---|
| `k == 8` and `n ≤ 4096` | v5.7 |
| `k == 16` and `n ≤ 8192` | v5.8 |
| otherwise | v5.9 |

### `RSRCudaAdaptiveMultiplier`

*Files: `rsr_cuda_adaptive.py`*

Pads square matrices to a multiple of `k` and delegates to `RSRCudaV5_10Multiplier`.

---

## Ternary CPU: `multiplier/bit_1_58`

### What Is Different

Binary RSR only needs to know which rows receive `+agg`. Ternary RSR must track both `+agg` and `-agg` rows. The optimization story in the ternary family is reducing the cost of storing and reading this signed scatter information.

### `RSRTernaryV1_4Multiplier`

*Files: `multiplier/bit_1_58/cpu/rsr_v1_4.py`, `kernels/bit_1_58/cpu/rsr_ternary_prep.c`, `kernels/bit_1_58/cpu/rsr_ternary.c`*

Splits each block into positive and negative bit patterns, combines into a `2k`-bit ternary code (`(pos_val << k) | neg_val`), groups with counting sort over `4^k` buckets. Inference gathers, sums, then scatters with explicit `scatter_rows` and `scatter_signs` arrays. Fused C, not Python.
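
The ternary encoding can be sketched as follows (`ternary_codes` is an illustrative name, not a repository function):

```python
import numpy as np

def ternary_codes(W_block):
    """Encode a {-1, 0, +1} block: one k-bit pattern for the +1 rows,
    one for the -1 rows, combined as (pos_val << k) | neg_val into a
    single 2k-bit code per column.
    """
    k, n = W_block.shape
    weights = 1 << np.arange(k)
    pos = ((W_block == 1).T @ weights).astype(np.int64)   # +1 positions
    neg = ((W_block == -1).T @ weights).astype(np.int64)  # -1 positions
    return (pos << k) | neg

W = np.array([[ 1, -1],
              [ 0,  1],
              [-1,  0]])          # k = 3, two columns
codes = ternary_codes(W)
# column 0: pos bits {0} -> 1, neg bits {2} -> 4, code = (1 << 3) | 4
assert codes[0] == (1 << 3) | 4
```

Since no row can be both `+1` and `-1`, the two masks never overlap, and the combined code is a unique label over the `4^k`-bucket pattern space.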

### `RSRTernaryV3_1Multiplier`

*Files: `multiplier/bit_1_58/cpu/rsr_v3_1.py`, `kernels/bit_1_58/cpu/rsr_ternary_v3_1.c`*

Over v1.4:
- `perms` and `group_ends` shrunk to `uint16` — halves metadata bandwidth in the hot loop
- Wrapper caches ctypes pointers (no repeated Python→ctypes setup)
- Kernel uses `schedule(static)`, small-group fast paths (switch for len ≤ 4), and `_mm_prefetch` with T0 hints

### `RSRTernaryV3_3Multiplier`

*Files: `multiplier/bit_1_58/cpu/rsr_v3_3.py`, `kernels/bit_1_58/cpu/rsr_ternary_v3_3.c`*

Replaces the variable-length signed scatter arrays with two fixed-size `uint16` masks per group: `pos_mask` and `neg_mask`. The kernel iterates set bits with `__builtin_ctz` and `mask &= mask - 1`. Requires `k ≤ 16`.

This is **the key ternary CPU optimization**: metadata per group stops depending on how many rows are active. Two compact masks replace a variable-length scatter list.
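
In Python terms, the mask-based signed scatter looks like this (the C kernel uses `__builtin_ctz`; the `bit_length` trick below is its Python analogue):

```python
def signed_scatter(agg, pos_mask, neg_mask, out):
    """v3.3-style sketch: two fixed-size masks per group replace the
    variable-length (scatter_rows, scatter_signs) arrays. Rows in
    pos_mask receive +agg, rows in neg_mask receive -agg.
    """
    for mask, sign in ((pos_mask, 1.0), (neg_mask, -1.0)):
        while mask:
            row = (mask & -mask).bit_length() - 1  # count-trailing-zeros analogue
            out[row] += sign * agg
            mask &= mask - 1                       # clear lowest set bit
    return out

out = [0.0] * 4
signed_scatter(2.5, pos_mask=0b0011, neg_mask=0b1000, out=out)
assert out == [2.5, 2.5, 0.0, -2.5]
```

The inner loop cost now scales with the number of set bits, and the metadata per group is a constant four bytes regardless of how many rows are active.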

### `RSRTernaryNonSquareMultiplier`

*Files: `multiplier/bit_1_58/cpu/rsr_nonsquare.py`, `kernels/bit_1_58/cpu/rsr_ternary_prep_nonsquare.c`*

Pads rows to a multiple of `k`. Dispatches:

| Condition | Kernel |
|:---|:---|
| `n_cols ≥ 4096` and `k ≤ 16` | v3.3 |
| otherwise | v3.1 |

v3.3 wins when metadata bandwidth matters enough to justify mask creation; v3.1 is lighter for smaller shapes.

### CPU Runtime Variants for Model Inference

*Files: `multiplier/bit_1_58/cpu/rsr_runtime.py`, `kernels/bit_1_58/cpu/rsr_ternary_v3_1_batch.c`, `kernels/bit_1_58/cpu/rsr_ternary_v3_3_batch.c`*

These explain the large end-to-end CPU inference speedups.

**`RSRPreprocessedMultiplier`** — loads saved RSR tensors from safetensors, skipping preprocessing at serve time. Dispatches to v3.1 or v3.3 based on the same rules.

**`fused_call`** — fuses BitNet activation quantization and RSR GEMV into a single C call. Since the kernel already touches the full input vector, quantizing it in the same call avoids extra Python dispatch and temporary handling.

**`RSRBatchMultiplier` / `RSRBatchMultiplierV31`** — batch multiple layers sharing the same input vector (e.g. `q_proj + k_proj + v_proj`, or `gate_proj + up_proj`). Quantize the input once, execute all GEMVs in one C call, and parallelize across the combined block pool with OpenMP.

This batching and fusion layer is a major reason the CPU LLM path is much faster than calling one optimized GEMV per layer:
- One quantization instead of many
- One native call instead of many
- More total blocks for OpenMP to distribute across cores
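
The batching idea reduces to a few lines. In this sketch, `fake_quantize` is a made-up stand-in for BitNet activation quantization, and the list comprehension stands in for the single batched C call:

```python
import numpy as np

def batched_matvec(Ws, v, quantize):
    """Batch sketch: several layers (e.g. q/k/v projections) share one
    input vector, so the activation is quantized once and every GEMV
    reads the same quantized copy. The real path does all of this in
    one C call with OpenMP over the combined block pool.
    """
    v_q = quantize(v)                     # one quantization, not len(Ws)
    return [W @ v_q for W in Ws]          # conceptually one native call

def fake_quantize(v):
    # Hypothetical stand-in: snap activations to a coarse grid.
    return np.round(v * 4) / 4

rng = np.random.default_rng(0)
v = rng.standard_normal(16)
Ws = [rng.integers(-1, 2, size=(8, 16)) for _ in range(3)]
outs = batched_matvec(Ws, v, fake_quantize)
assert np.allclose(outs[0], Ws[0] @ fake_quantize(v))
```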

---

## Ternary CUDA: `multiplier/bit_1_58/cuda`

### Shared Preprocessing

*Files: `_prep_cuda.py`, `_prep_cuda_nonsquare.py`, `_prep_v2_common.py`*

Preprocessing runs on CPU. Then:
1. Sort permutations inside each group
2. Convert signed scatter lists into `pos_mask` / `neg_mask`
3. Drop all-zero groups (they never change the output)
4. Pack each remaining group into one `uint64`:

```
bits [0:15]   start
bits [16:31]  length
bits [32:47]  pos_mask
bits [48:63]  neg_mask
```

One 64-bit word fully describes a group. The kernel reads one value and starts working immediately. Skipping zero groups removes useless warps and gathers.
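
A pack/unpack sketch of that layout (helper names are illustrative):

```python
def pack_group(start, length, pos_mask, neg_mask):
    """Pack one group's metadata into a single 64-bit word, matching
    the layout above: start in bits [0:15], length in [16:31],
    pos_mask in [32:47], neg_mask in [48:63].
    """
    assert max(start, length, pos_mask, neg_mask) <= 0xFFFF
    return start | (length << 16) | (pos_mask << 32) | (neg_mask << 48)

def unpack_group(word):
    """Recover (start, length, pos_mask, neg_mask) from one word."""
    return (word & 0xFFFF, (word >> 16) & 0xFFFF,
            (word >> 32) & 0xFFFF, (word >> 48) & 0xFFFF)

word = pack_group(start=100, length=7, pos_mask=0b0101, neg_mask=0b0010)
assert unpack_group(word) == (100, 7, 0b0101, 0b0010)
```

Each 16-bit field is why the `k ≤ 16` and column-count limits exist: everything a group needs has to fit in one word.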

### `RSRTernaryCudaV2_0Multiplier`

*Files: `rsr_cuda_v2_0.py`, `kernels/bit_1_58/cuda/rsr_ternary_v2_0.cu`*

- One CUDA block per row block
- One warp processes one group: reduces the group sum, then scatters into per-warp shared-memory partials using the positive/negative masks
- Partials reduced across warps at block end
- Zero groups removed during preprocessing
- 16-bit permutations
- Compile-time `k` specializations for 2, 4, 6, 8, 10, 12 — lets the compiler unroll the scatter loop
- Thread count: 256 for `k ≤ 4`, 512 otherwise

**This is the retained ternary CUDA kernel.**

### `RSRPreprocessedCudaMultiplier`

*Files: `multiplier/bit_1_58/cuda/rsr_runtime.py`*

Loads saved CUDA RSR tensors and reuses the v2.0 kernel at inference time. Unlike the CPU path, the current CUDA runtime does not yet fuse activation quantization across sibling layers — the speedup comes from the compact v2.0 kernel itself.

---

## Currently Active Paths

| Use case | Entry point | Inner kernel |
|:---|:---|:---|
| Binary CPU | `RSRCppNonSquareMultiplier` | v4.2 |
| Binary CUDA | `RSRCudaV5_9NonSquareMultiplier` | v5.10 adaptive dispatcher |
| Ternary CPU | `RSRTernaryNonSquareMultiplier` | v3.1 or v3.3 |
| Ternary CPU (model runtime) | `RSRPreprocessedMultiplier` | fused batch v3.1/v3.3 |
| Ternary CUDA | `RSRTernaryCudaV2_0Multiplier` | v2.0 |

---

## Optimization Progression

The speed comes from this sequence:

1. **Change the algorithm.** Aggregate repeated column patterns once instead of repeating work per column.
2. **Match preprocessing to the discrete problem.** Counting sort over a small known pattern space beats general-purpose sorting.
3. **Shrink the hot metadata.** `uint16` perms, `uint16` masks, packed 64-bit words, dropped zero groups — all cut bandwidth.
4. **Improve gather locality.** Sorting indices within a group does not change the sum but makes the memory system happier.
5. **Remove unnecessary movement.** Direct-gather designs delete the extra write/read of intermediate buffers.
6. **Keep updates local.** Shared-memory or per-warp partials avoid fighting over global output writes.
7. **Fuse surrounding work.** Skipping repeated activation quantization and repeated Python calls matters almost as much as the GEMV kernel itself.