feat: MLA absorption for DeepSeek V2/V3 — fuse low-rank Q/K/V into standard dense tensors#96
Hey @mvkorobkov, same situation as #103: the branch is conflicting against current main. Could you rebase? If you'd rather, I can also cherry-pick onto a fresh branch with your attribution preserved on the commits. Let me know which you prefer. Once rebased, I'll do a proper review of the gqa_attention_asym kernel and the DeepSeek geometry plumbing.
DS-V3 absorbed attention has qk_head_dim=192 (nope=128 + rope=64) but v_head_dim=128. The existing gqa_attention uses a single head_dim for all projections, which would corrupt V slicing and output shape.

gqa_attention_asym accepts separate qk_head_dim and v_head_dim:
- Q/K sliced with qk_head_dim (dot-product stays in the larger space)
- V sliced and output written with v_head_dim
- Returns (seq, num_q * v_head_dim)

When qk_head_dim == v_head_dim the function is numerically identical to gqa_attention (verified by the asym_sym_equivalence_when_dims_equal test).

4 tests added: shape, finiteness, sym-equivalence, seq=1 causal.

Note: gqa kernels live in larql-compute (post-ADR-0022 Step 2d); this commit places the asym variant alongside the existing gqa_attention there.
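To make the slicing rules above concrete, here is a minimal sketch of causal grouped-query attention with separate Q/K and V head dims. It is not the kernel from this PR: the function name, the flat row-major buffers, and the `1/sqrt(qk_hd)` scale are all assumptions chosen for illustration.

```rust
/// Illustrative only (hypothetical name, not the larql kernel).
/// q: (seq, num_q * qk_hd), k: (seq, num_kv * qk_hd), v: (seq, num_kv * v_hd),
/// all row-major. Returns (seq, num_q * v_hd) with causal masking.
fn gqa_attention_asym_sketch(
    q: &[f32], k: &[f32], v: &[f32],
    seq: usize, num_q: usize, num_kv: usize,
    qk_hd: usize, v_hd: usize,
) -> Vec<f32> {
    assert_eq!(q.len(), seq * num_q * qk_hd);
    assert_eq!(k.len(), seq * num_kv * qk_hd);
    assert_eq!(v.len(), seq * num_kv * v_hd);
    assert!(num_kv > 0 && num_q % num_kv == 0);
    let group = num_q / num_kv; // query heads sharing one KV head
    let scale = 1.0 / (qk_hd as f32).sqrt(); // assumed scale
    let mut out = vec![0.0f32; seq * num_q * v_hd];
    for h in 0..num_q {
        let kvh = h / group;
        for t in 0..seq {
            // Q and K are sliced with qk_hd.
            let q_row = &q[t * num_q * qk_hd + h * qk_hd..][..qk_hd];
            let mut scores: Vec<f32> = (0..=t)
                .map(|s| {
                    let k_row = &k[s * num_kv * qk_hd + kvh * qk_hd..][..qk_hd];
                    q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() * scale
                })
                .collect();
            // Numerically stable softmax over positions 0..=t.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut denom = 0.0;
            for s in scores.iter_mut() {
                *s = (*s - max).exp();
                denom += *s;
            }
            // V is sliced, and the output written, with v_hd.
            let o = &mut out[t * num_q * v_hd + h * v_hd..][..v_hd];
            for (s, w) in scores.iter().enumerate() {
                let v_row = &v[s * num_kv * v_hd + kvh * v_hd..][..v_hd];
                for (oi, vi) in o.iter_mut().zip(v_row) {
                    *oi += (w / denom) * vi;
                }
            }
        }
    }
    out
}
```

With seq=1 the softmax collapses to a single weight of 1, so each query head's output is exactly the V row of its KV head, which is the property a seq=1 causal test can check.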
Three new optional fields on ModelConfig:
- qk_nope_head_dim: non-RoPE part of Q/K head dim (DS-V3: 128)
- qk_rope_head_dim: RoPE-rotated part of Q/K head dim (DS-V3: 64)
- v_head_dim: V projection head dim (DS-V3: 128)

Parsed from config.json (qk_nope_head_dim / qk_rope_head_dim / v_head_dim). Trait accessors added to ModelArchitecture with None defaults. DeepSeekArch overrides them to read from config. DS-V3 detection test extended to verify that all three fields round-trip. Two GGUF test-only ModelConfig literals updated to include None stubs.
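The relationship between the three fields and the absorbed head dims can be sketched as follows. `MlaGeometry` and `absorbed_dims` are hypothetical names standing in for the real `ModelConfig` and its accessors, which are not shown in this thread.

```rust
/// Hypothetical, trimmed-down stand-in for the geometry fields on ModelConfig.
#[derive(Debug, Clone, Default)]
struct MlaGeometry {
    qk_nope_head_dim: Option<usize>,
    qk_rope_head_dim: Option<usize>,
    v_head_dim: Option<usize>,
}

impl MlaGeometry {
    /// Returns Some((qk_head_dim, v_head_dim)) only when all three fields
    /// are present; any missing field means "not MLA" in this sketch.
    fn absorbed_dims(&self) -> Option<(usize, usize)> {
        let nope = self.qk_nope_head_dim?;
        let rope = self.qk_rope_head_dim?;
        let v = self.v_head_dim?;
        Some((nope + rope, v))
    }
}
```

For the DS-V3 values listed above (128, 64, 128) this yields a 192-wide Q/K head and a 128-wide V head.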
…eight matrices

Implements `mla_absorb::absorb()` which converts the four MLA weight matrices (kv_a, kv_b, q_a, q_b) into standard dense Q/K/V tensors compatible with `gqa_attention_asym`.

Key correctness points:
- rope-K is MQA: single row in kv_a[kv_lora..] replicated num_kv times in absorbed K (not per-head in the input tensor)
- DS-V3 native per-head layout [nope|rope] → LARQL convention [rope|nope], applied symmetrically to Q and K during absorption
- V: straightforward kv_b[nope+v_hd slice] @ kv_compress

Three tests (3 passed):
- absorbed_forward_matches_reference: reference MLA forward vs absorbed path through gqa_attention_asym must match within 1e-4
- absorbed_shapes: output tensor dimensions
- rope_k_is_broadcast_not_zero: single rope-K correctly replicated across heads
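The K side of the absorption described above can be sketched roughly as follows. This is not `mla_absorb::absorb()` itself: the function names, the row-major layouts (kv_a as `(kv_lora + rope, hidden)`, kv_b as `(num_kv * (nope + v_hd), kv_lora)` with per-head rows `[nope | v]`), and the omission of any normalization between the two projections are assumptions made for illustration.

```rust
/// Row-major matmul: a is (m, k), b is (k, n); returns (m, n).
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                c[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    c
}

/// Hypothetical sketch: builds absorbed K of shape
/// (num_kv * (rope + nope), hidden) with per-head layout [rope | nope].
fn absorb_k_sketch(
    kv_a: &[f32], kv_b: &[f32],
    hidden: usize, kv_lora: usize, num_kv: usize,
    nope: usize, rope: usize, v_hd: usize,
) -> Vec<f32> {
    assert_eq!(kv_a.len(), (kv_lora + rope) * hidden);
    assert_eq!(kv_b.len(), num_kv * (nope + v_hd) * kv_lora);
    let compress = &kv_a[..kv_lora * hidden]; // (kv_lora, hidden)
    let rope_rows = &kv_a[kv_lora * hidden..]; // (rope, hidden), shared by all heads
    let mut k = Vec::with_capacity(num_kv * (rope + nope) * hidden);
    for h in 0..num_kv {
        let start = h * (nope + v_hd) * kv_lora;
        let b_nope = &kv_b[start..start + nope * kv_lora]; // (nope, kv_lora)
        // rope part first: the single MQA block is copied into every head.
        k.extend_from_slice(rope_rows);
        // nope part second: per-head up-projection fused with the compression.
        k.extend(matmul(b_nope, compress, nope, kv_lora, hidden));
    }
    k
}
```

Copying `rope_rows` into every head is what makes the absorbed K larger than the original tensors, and it is the behaviour a broadcast test such as rope_k_is_broadcast_not_zero would guard.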
Rebased onto current main (810f163). Branch now contains 5 focused MLA commits (656 insertions, 10 deletions):
Dropped two commits that didn't belong here:
597d2ca to
2d10daa
Compare
write_model_weights_with_opts now accepts DS-V3 / MLA architectures when all three geometry fields (qk_nope_head_dim, qk_rope_head_dim, v_head_dim) are present in config.json. When detected:
- skips the standard-attention guard
- per layer: fetches kv_a/kv_b/q_a/q_b projections, calls mla_absorb::absorb, writes the resulting dense Q/K/V under the standard attn_q/k/v key names
- O projection is passed through unchanged (no absorption needed)

The loader remains MLA-unaware: it reads standard Q/K/V tensors just as for any Llama/Mistral model. The extra storage cost (absorbed K replicates the MQA rope-K row num_kv times) is acceptable for DS-V3 full scale (~3.5 GB extra per 61 layers on num_kv=128).

All 971 larql-vindex unit + integration tests pass.
#67

llama.cpp emits DeepSeek-V2/V3 (and Kimi K2) MLA geometry in the GGUF metadata under {arch}.attention.* and {arch}.rope.dimension_count. `to_config_json` was dropping every one of these fields, so the parsed ModelConfig had MLA disabled and PR #96's absorption never fired for GGUF-sourced inputs.

This surfaces the relevant fields into the HF-shaped config the parser consumes:
- `attention.q_lora_rank` → `q_lora_rank`
- `attention.kv_lora_rank` → `kv_lora_rank`
- `attention.key_length[_mla]` → `qk_nope_head_dim` (= key_length − rope.dim)
- `attention.value_length[_mla]` → `v_head_dim`
- `rope.dimension_count` → `qk_rope_head_dim`

For per-head dims the loader prefers the `_mla` variants when present, since those carry the pre-absorption (DS-V3-standard) split that `mla_absorb::absorb` operates on. Kimi K2.6's GGUF exposes both forms (192/128 for `_mla`, 576/512 absorbed); we want 192/128.

Verified against Kimi K2.6 UD-Q8_K_XL GGUF metadata (the unsloth name is misleading: actual tensor types are BF16 + F32 + Q4_0, all already supported by larql's existing dequant).

Three new tests cover:
1. Kimi K2.6-shaped metadata → full MLA fields populated, MLA detected
2. Non-`_mla` variant fallback (DS-V2 with key_length only)
3. Non-MLA architectures (llama) keep their fields absent

281/281 larql-models tests pass.

Combined with PR #96 + #103 + #133, this unlocks inference-level extraction of the Kimi K2 family and any other DeepSeek-V2/V3 GGUF that exposes the standard MLA metadata.
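The field mapping and the `_mla` preference can be sketched as below. This is not `to_config_json`: the function and struct names are hypothetical, metadata is modelled as a plain `HashMap<String, u64>`, and gating MLA detection on the presence of `attention.kv_lora_rank` is an assumption about how non-MLA architectures are kept out.

```rust
use std::collections::HashMap;

/// Hypothetical container for the per-head MLA dims.
#[derive(Debug, PartialEq)]
struct MlaDims {
    qk_nope_head_dim: u64,
    qk_rope_head_dim: u64,
    v_head_dim: u64,
}

/// Sketch of the mapping: prefer `*_mla` keys, fall back to the plain ones.
fn mla_dims_from_gguf(arch: &str, meta: &HashMap<String, u64>) -> Option<MlaDims> {
    let get = |suffix: &str| meta.get(&format!("{arch}.{suffix}")).copied();
    // Assumption: treat the model as MLA only if kv_lora_rank is present.
    get("attention.kv_lora_rank")?;
    let rope = get("rope.dimension_count")?;
    let key_len = get("attention.key_length_mla")
        .or_else(|| get("attention.key_length"))?;
    let v_len = get("attention.value_length_mla")
        .or_else(|| get("attention.value_length"))?;
    Some(MlaDims {
        // qk_nope_head_dim = key_length - rope.dim; None on underflow.
        qk_nope_head_dim: key_len.checked_sub(rope)?,
        qk_rope_head_dim: rope,
        v_head_dim: v_len,
    })
}
```

With Kimi K2.6-shaped values (192/128 under the `_mla` keys, 576/512 under the plain keys, rope dimension 64), the sketch picks 192/128 and returns nope=128, rope=64, v=128.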
llama.cpp's gguf-split produces multi-file GGUFs (canonical naming: `<prefix>-<NNNNN>-of-<NNNNN>.gguf`). Each shard carries the full metadata header but only owns its own slice of tensors. The current `GgufFile::open` reads one file, so multi-shard models (Kimi K2.6 with 14 shards, DeepSeek-V4-Flash with 3 shards, and increasingly any large modern LLM) could not be loaded for vindex extraction.

This change:
1. Adds `ShardInfo` (path + data_offset) and a `shards: Vec<ShardInfo>` field on `GgufFile`. Single-file GGUFs get `shards.len() == 1`.
2. `GgufFile::open` detects multi-shard via the explicit `split.count` metadata key, falling back to the filename pattern when the splitter omits the metadata.
3. Discovers all sibling shards in the same directory by reconstructing filenames at the prefix's chosen width (`00001-of-00014` vs `001-of-003` both supported).
4. Appends each sibling's `tensor_infos` to the combined list, tagging them with the right `shard_idx`. Cross-checks the total against `split.tensors.count` when present.
5. `load_tensors_filtered` mmaps each shard lazily on first use and reads each tensor from `shards[info.shard_idx].path` at the right per-shard `data_offset`. Shards whose tensors are all skipped by `skip_key` are never opened.

Backward-compatible: existing `GgufFile::open` callers and the single-file test fixtures keep working with `shards = vec![…one…]`.
Tests (8 new + all existing pass):
- parse_shard_filename: canonical layout, plain `.gguf` rejection, mismatched widths rejection, 3-digit split width support
- discover_shard_siblings: complete set discovery from any-position shard, error when sibling missing
- open_multi_shard_combines_tensors_from_all_shards: builds two real 2-shard GGUFs with disjoint tensor sets, opens via either shard, verifies each tensor reads from its own shard's data section
- open_rejects_multi_shard_when_a_shard_file_is_missing
- existing 27 tests stay green; 286/286 larql-models tests pass

Combined with #96 (MLA absorption), #103 (Q3_K/Q5_K dequant), #133 (GGUF extract input), and #135 (DeepSeek-V2/V3 MLA metadata reading), this completes the chain: `larql extract --level inference` works end-to-end on Kimi K2.6 UD-Q8_K_XL and DeepSeek-V4-Flash multi-shard GGUFs.
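The filename handling behind steps 2 and 3 can be sketched as follows. These are not the PR's `parse_shard_filename` or `discover_shard_siblings`: the names, the `ShardName` struct, and the 1-based index check are assumptions, and the sketch only handles the name itself, not directory scanning.

```rust
/// Hypothetical parsed form of `<prefix>-<NNNNN>-of-<NNNNN>.gguf`.
#[derive(Debug, PartialEq)]
struct ShardName {
    prefix: String,
    index: usize, // assumed 1-based, as in `00001-of-00014`
    count: usize,
    width: usize, // zero-padded digit width shared by index and count
}

/// Returns None for plain `.gguf` names and for mismatched digit widths.
fn parse_shard_filename_sketch(name: &str) -> Option<ShardName> {
    let stem = name.strip_suffix(".gguf")?;
    let (rest, count_str) = stem.rsplit_once("-of-")?;
    let (prefix, index_str) = rest.rsplit_once('-')?;
    let all_digits = |s: &str| !s.is_empty() && s.bytes().all(|b| b.is_ascii_digit());
    if !all_digits(index_str)
        || !all_digits(count_str)
        || index_str.len() != count_str.len()
    {
        return None;
    }
    let index: usize = index_str.parse().ok()?;
    let count: usize = count_str.parse().ok()?;
    if prefix.is_empty() || index == 0 || index > count {
        return None;
    }
    Some(ShardName {
        prefix: prefix.to_string(),
        index,
        count,
        width: index_str.len(),
    })
}

/// Rebuilds a sibling's filename at the same zero-padded width.
fn sibling_name(prefix: &str, index: usize, count: usize, width: usize) -> String {
    format!("{prefix}-{index:0width$}-of-{count:0width$}.gguf")
}
```

Keeping the parsed `width` is what lets sibling names be reconstructed for both the 5-digit and 3-digit layouts without guessing the padding.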
Summary
- `gqa_attention_asym`: new attention kernel in `larql-inference` that handles asymmetric `qk_head_dim` / `v_head_dim` (required for absorbed MLA tensors where Q/K use 192-dim heads but V uses 128-dim heads in DS-V3)
- `ModelConfig`: `qk_nope_head_dim`, `qk_rope_head_dim`, `v_head_dim` parsed from `config.json`; `DeepSeekArch` exposes them via trait methods
- `mla_absorb`: new module in `larql-vindex` that fuses the four DS-V2/V3 low-rank attention projections (`kv_a`, `kv_b`, `q_a`, `q_b`) into standard dense Q/K/V weight matrices
- `write_model_weights`: F32 weight writer now accepts MLA architectures: detects full geometry, runs absorption per layer, writes absorbed Q/K/V under standard key names so the loader needs no MLA awareness

Why absorption
DS-V2/V3 stores attention as four low-rank matrices. Absorbing them into standard Q/K/V at extraction time means:
Correctness
Key details:
- `kv_a` rope-K is MQA (one shared row for all KV heads, not per-head): replicated `num_kv` times when building absorbed K
- DS-V3 native per-head layout is `[nope | rope]`; LARQL convention is `[rope | nope]`, so absorption reorders symmetrically for both Q and K
- `absorbed_forward_matches_reference` test: reference MLA forward pass vs absorbed path through `gqa_attention_asym` must agree within 1e-4 (f32 precision)

Test plan
- `cargo test -p larql-inference -- gqa_attention_asym`: 4 tests (shape, finite, sym-equivalence, causal)
- `cargo test -p larql-vindex -- mla_absorb`: 3 tests (forward equivalence, shapes, rope broadcast)
- `cargo test -p larql-models`: existing DS-V3 detection tests extended with new geometry accessors
- `cargo test -p larql-vindex`: 971 tests, 0 failures