
Gemma-4 MTP: skip pooling/sampling epilogue for MTP graphs (fixes decode crash) + widen DKQ=512 TILE routing to all gqa_ratio #25

Open
PhilEgly wants to merge 2 commits into AtomicBot-ai:feature/turboquant-kv-cache from PhilEgly:fix/gemma4-mtp-build-pooling-epilogue

@PhilEgly PhilEgly commented Jun 6, 2026

Summary

Gemma-4 MTP speculative decode crashed with a silent access violation (Windows 0xC0000005) on the first draft step, for both E4B and 12B targets. This PR fixes the crash and makes MTP speculative decode work end-to-end.

The crash was traced by binary-searching the decode path with flushed log probes on an E4B repro (E4B has gqa_ratio=2, so it clears the fattn DKQ=512 abort and reaches the real bug; the 12B's DKQ=512 abort masks it). Instrumentation proved the verify batch's set_inputs and graph_compute both complete with status 0 and decode() returns 0 — the crash is later, in MTP graph construction.
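
The flushed log probes mentioned above could look roughly like the sketch below. It is a hypothetical helper, not code from this PR: the names `probe` and `probe_line` and the message format are assumptions. What it illustrates is flushing after every marker, so that the last line printed before an access violation bounds where the crash happened.

```cpp
#include <cassert>
#include <cstdio>
#include <string>

// Hypothetical probe helper (not part of llama.cpp). Builds one marker line
// naming the stage that just finished and the status it returned.
inline std::string probe_line(const char * stage, int status) {
    assert(stage != nullptr);
    return std::string("[probe] ") + stage + " status=" + std::to_string(status) + "\n";
}

// Writes the marker and flushes immediately. Without the flush, a hard crash
// in the next stage can discard buffered output and hide the last marker.
inline void probe(std::FILE * out, const char * stage, int status) {
    const std::string line = probe_line(stage, status);
    std::fputs(line.c_str(), out);
    std::fflush(out);
}
```

Placing a call such as `probe(stderr, "graph_compute", status)` halfway along the suspect path, then moving it toward whichever half still crashes, gives the binary search described above.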

Root cause

The MTP speculative path calls llama_set_embeddings(ctx, true) so the main decode emits backbone hidden states for the drafter, and that flag stays set. llama_model::build_graph() runs a shared build_pooling() + build_sampling() epilogue on every graph. build_pooling() early-returns only if (!cparams.embeddings). With embeddings forced true, it runs against the MTP graph's t_embd (= h_post, the backbone projection — which has no pooling inputs) and dereferences bad memory.

A second latent bug rode along: build_sampling() would attach backend samplers to the MTP logits, which MTP never consumes (it does its own on-device argmax).
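
The interaction between the sticky embeddings flag and the shared epilogue can be modeled with a small sketch. Everything here is a simplified stand-in: `graph_type`, `cparams_model`, and `epilogue_stages` are hypothetical names, and sampling is assumed to be attached unconditionally, which glosses over the real `build_sampling()` logic. The `gate_mtp` switch contrasts the behavior described in this section with a graph-type gate.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-ins; the real types in llama.cpp carry far more state.
enum graph_type { GRAPH_TYPE_DEFAULT, GRAPH_TYPE_MTP };

struct cparams_model {
    bool embeddings = false; // models the flag set by llama_set_embeddings(ctx, true)
};

// Returns the names of the epilogue stages that would be appended to a graph.
// gate_mtp == false models the ungated epilogue, gate_mtp == true the gated one.
std::vector<std::string> epilogue_stages(graph_type gtype, const cparams_model & cp, bool gate_mtp) {
    std::vector<std::string> stages;
    if (gate_mtp && gtype == GRAPH_TYPE_MTP) {
        return stages; // MTP graph is treated as self-contained: no epilogue at all
    }
    if (cp.embeddings) {
        // The only early-return in pooling is on !embeddings, so a sticky
        // embeddings flag lets pooling run on every graph type.
        stages.push_back("pooling");
    }
    stages.push_back("sampling"); // assumption: always attached in this model
    return stages;
}
```

With `embeddings` left at `true`, the ungated call returns both stages for `GRAPH_TYPE_MTP`, which is the situation that led to the crash. The gated call returns none for MTP and leaves the default graph unchanged.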

Fix

  • src/llama-model.cpp — gate the shared epilogue: if (params.gtype != LLM_GRAPH_TYPE_MTP) { build_pooling(...); build_sampling(); }. The MTP graph is self-contained and must not receive the main-decode pooling/sampling layers. This is the actual unblock.
  • ggml/src/ggml-cuda/fattn.cu — route all DKQ=512 cases to the TILE kernel. This generalizes the existing fix/cuda-mma-dkq512-fallback branch, whose gqa_ratio < 3 guard only covers E4B (gqa_ratio=2). Gemma-4 12B/26B/31B have gqa_ratio=8, fail that guard, and still hit the MMA GGML_ABORT. Routing all DKQ=512 to TILE covers them and the no-mask MTP cross-attention path. (Happy to rebase this onto fix/cuda-mma-dkq512-fallback if you'd prefer it land there.)
  • src/llama-graph.cpp — guard llm_graph_input_embd::set_input on embd != null (mirrors the existing can_reuse() check), so gemma4's per-layer-token input — which reuses this class but only allocates tokens — can't deref a null embd.
  • src/llama-context.cpp — re-reserve on a real set_embeddings() mode change (topology change); synchronize the main scheduler in decode_mtp_async before handoff so the MTP worker doesn't race the preceding decode's in-flight KV writes.
  • convert_hf_to_gguf.py — register Gemma4UnifiedAssistantForCausalLM as an alias of Gemma4AssistantForCausalLM so newer assistant checkpoints convert.
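
The effect of widening the DKQ=512 routing can be illustrated with a reduced model of the kernel choice. This is a sketch under stated assumptions, not the dispatcher in fattn.cu: the enum and both `pick_kernel_*` functions are hypothetical, and the non-512 branch returning an MMA kernel is a placeholder for the many conditions the real code checks.

```cpp
#include <cassert>

// Hypothetical, reduced model of the flash-attention kernel choice.
enum fattn_kernel { FATTN_KERNEL_NONE, FATTN_KERNEL_TILE, FATTN_KERNEL_MMA };

// Earlier guard: only small GQA ratios at DKQ == 512 were routed to TILE.
// Larger ratios got no kernel, which the caller turns into an abort.
fattn_kernel pick_kernel_old(int dkq, int gqa_ratio) {
    if (dkq == 512) {
        return gqa_ratio < 3 ? FATTN_KERNEL_TILE : FATTN_KERNEL_NONE;
    }
    return FATTN_KERNEL_MMA; // placeholder for every other head size
}

// Widened routing: every DKQ == 512 case goes to TILE, whatever the GQA ratio.
fattn_kernel pick_kernel_new(int dkq, int gqa_ratio) {
    (void) gqa_ratio; // no longer part of the DKQ == 512 decision
    if (dkq == 512) {
        return FATTN_KERNEL_TILE;
    }
    return FATTN_KERNEL_MMA; // placeholder for every other head size
}
```

In this model, gqa_ratio=2 (E4B) gets TILE under both versions, while gqa_ratio=8 (12B/26B/31B) gets no kernel under the old guard and TILE under the new routing.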

Note on branch overlap

The build_graph epilogue guard logically belongs with the MTP work on feature/gemma-mtp, and the fattn.cu change generalizes fix/cuda-mma-dkq512-fallback. I've targeted the default branch with the combined set for reviewability — happy to split into per-branch PRs if that fits your workflow better.

Validation

Clean rebuild (all diagnostics stripped), 4/4 requests each, server healthy, clean shutdown:

  • E4B (gemma-4-E4B-it-Q4_K_M + gemma-4-E4B-it-assistant.Q4_K_M): sample responses "Hello! How can I help you today? 😊", "2, 3, 5", and "17+26 => 43"; draft acceptance between 1/4 and 12/44.
  • 12B (gemma-4-12b-it-IQ4_XS + converted gemma-4-12B-it-assistant-f16): exercises both the fattn DKQ=512 fix (gqa_ratio=8) and the pooling fix; sample response "17+26 => 43"; draft acceptance between 22/36 and 49/58. (Empty content on some prompts is thinking-mode token-budget behavior, not a crash: 12B has thinking=1 and burns the 80-tok cap on <channel>thought.)

Tested on RTX 5070 Ti (Blackwell sm_120), CUDA 13.3. The epilogue guard is not GPU-specific: the crash it prevents would affect any Gemma-4 MTP target on any backend.

🤖 Generated with Claude Code

PhilEgly and others added 2 commits June 3, 2026 20:42
…ers == n_layer

The GEMMA4 hparam-loading path already disables KV reuse when shared_kv_layers
leaves no dedicated KV layers, but the GEMMA4_ASSISTANT path next to it does
not. For 26B/31B assistants where block_count == shared_kv_layers == 4, this
leaves hparams.n_layer_kv_from_start at 0 and downstream tensor-creation code
hits a 0-length vector subscript (visible on Windows debug-iterators as
"invalid vector subscript"; UB elsewhere).

Mirrors the existing GEMMA4 protection a few lines above. Reproduces with
google/gemma-4-26B-A4B-it-assistant converted via convert_hf_to_gguf.py.

Edge variants (E2B/E4B) and the new 2026-06-03 12B Unified assistant likely
have different shared_kv_layers values that avoid this edge case, which is
why current AtomicChat-published GGUFs do not exhibit it.
…ode crash)

Gemma-4 MTP speculative decode crashed with a silent access violation on the
first draft step. Root cause: the speculative path calls llama_set_embeddings(
ctx, true) so the main decode emits backbone hidden states for the drafter, and
that flag stays set. llama_model::build_graph() runs a shared build_pooling()/
build_sampling() epilogue on every graph; build_pooling only early-returns when
!cparams.embeddings, so it ran against the MTP graph's t_embd (= h_post, the
backbone projection, which has no pooling inputs) and dereferenced bad memory.

Fix: gate the epilogue on params.gtype != LLM_GRAPH_TYPE_MTP. The MTP graph is
self-contained (own logits / h_post / on-device argmax) and must not get the
main-decode pooling or backend-sampling layers. This also stops build_sampling
from attaching backend samplers to MTP logits, which MTP never consumes.

Supporting changes found during the investigation:
- ggml-cuda/fattn.cu: route ALL DKQ=512 cases to the TILE kernel. The previous
  gqa_ratio<3 guard returned BEST_FATTN_KERNEL_NONE for Gemma-4 12B/26B/31B
  (gqa_ratio=8), which aborts in the MMA dispatcher. TILE supports DKQ=512 with
  ncols2 fallback. Required for those targets to reach the MTP path at all.
- llama-graph.cpp: guard llm_graph_input_embd::set_input on embd != null
  (mirrors the existing can_reuse() check) so gemma4's per-layer-token input,
  which reuses this class but only allocates tokens, can't deref a null embd.
- llama-context.cpp: re-reserve on a real set_embeddings() mode change (topology
  change); synchronize the main scheduler in decode_mtp_async before handing off
  so the MTP worker doesn't race the preceding decode's in-flight KV writes.
- convert_hf_to_gguf.py: register Gemma4UnifiedAssistantForCausalLM as an alias
  of Gemma4AssistantForCausalLM so newer assistant checkpoints convert.

Validated: E4B and 12B MTP speculative decode now run end-to-end (4/4 requests
each, server healthy, draft tokens accepted).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>