This repo contains the code implementation of our paper MKA: Memory-Keyed Attention for Efficient Long-Context Reasoning.
The main ideas include:
- MKA (3-path hierarchical memory attention)
- FastMKA (route-fused variant for speed)
- CUDA extensions for fused routing + online softmax
- Reproducible training/evaluation scripts
The repo is organized as follows:
- `mka/layers/`: PyTorch modules (`MKAFullAttention`, `FastMKAAttention`)
- `mka/hf/`: HuggingFace monkey-patch support (Qwen/Llama style `self_attn`)
- `mka/cuda/`: CUDA extensions (`fastmka_attn`, optional `fused_route_mka`)
- `mka/config/`: optional YAML fields for `memory_hierarchy` etc.
- `mka/utils/repro.py`: global RNG seeding for reproducible runs
- `scripts/`: train/eval/benchmark entry points
- `configs/`: experiment configs
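The global seeding utility can be sketched as follows (the function name `seed_everything` and the `deterministic` flag are illustrative; the actual API lives in `mka/utils/repro.py` and may differ):

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds CPU and all CUDA devices
    os.environ["PYTHONHASHSEED"] = str(seed)
    if deterministic:
        # Stricter (and slower): force deterministic cuDNN kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
```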
To get started:

```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
python scripts/train_wikitext2.py --config configs/qwen7b_fastmka.yaml
```
FastMKA forward path in PyTorch:
- L1 local memory (`X`)
- L2 causal session summary (prefix EMA)
- Optional L3 retrieved memory
- Learned routing gate (`softmax(MLP(Q))`)
- Route-fusion before a single KV projection
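A minimal PyTorch sketch of this route-fusion idea (shapes and the name `FastMKASketch` are illustrative, not the repo's actual module; the L2 prefix EMA is approximated here by a prefix mean for brevity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastMKASketch(nn.Module):
    """Route-fuse L1/L2 memories into one hidden stream, then one attention."""

    def __init__(self, d_model: int, n_heads: int, n_levels: int = 2):
        super().__init__()
        self.n_heads = n_heads
        self.route_mlp = nn.Linear(d_model, n_levels)  # routing gate from queries
        self.qkv = nn.Linear(d_model, 3 * d_model)     # single fused QKV projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L1: local memory is the token stream itself.
        l1 = x
        # L2: causal session summary (prefix mean stands in for the prefix EMA).
        denom = torch.arange(1, x.size(1) + 1, device=x.device).view(1, -1, 1)
        l2 = x.cumsum(dim=1) / denom
        levels = torch.stack([l1, l2], dim=0)           # (L, B, T, D)
        # Learned routing gate: softmax over levels, computed from queries.
        gate = F.softmax(self.route_mlp(x), dim=-1)     # (B, T, L)
        fused = (gate.permute(2, 0, 1).unsqueeze(-1) * levels).sum(dim=0)
        # One KV projection and one causal attention over the fused stream.
        q, k, v = self.qkv(fused).chunk(3, dim=-1)
        B, T, D = q.shape
        h = self.n_heads
        q, k, v = (t.view(B, T, h, D // h).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out.transpose(1, 2).reshape(B, T, D)
```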
MKA full path in PyTorch:
- Per-level attention over L1/L2/L3
- Soft mixture over outputs
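The full path can be sketched as per-level attention followed by a soft mixture over the per-level outputs (illustrative function, not the repo's `MKAFullAttention`):

```python
import torch
import torch.nn.functional as F


def mka_full_sketch(q, kvs, mix_logits):
    """Attend to each memory level separately, then softly mix the outputs.

    q:          (B, H, T, Dh) queries
    kvs:        list of (k, v) pairs, one per level, each (B, H, S, Dh)
    mix_logits: (B, T, L) unnormalized mixture weights, L = number of levels
    """
    outs = []
    for k, v in kvs:
        # Per-level attention; causal masking (relevant for the L1 path)
        # is omitted here for simplicity.
        outs.append(F.scaled_dot_product_attention(q, k, v))
    outs = torch.stack(outs, dim=0)                    # (L, B, H, T, Dh)
    w = F.softmax(mix_logits, dim=-1)                  # (B, T, L)
    w = w.permute(2, 0, 1).unsqueeze(2).unsqueeze(-1)  # (L, B, 1, T, 1)
    return (w * outs).sum(dim=0)                       # (B, H, T, Dh)
```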
CUDA design:
- Tiled QK score calculation
- Online max/denominator (`m`, `z`) update
- Fused route application before attention
- Causal masking support
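The online max/denominator recurrence the kernel performs per tile can be sketched in plain PyTorch (one query row, streamed over score chunks; the CUDA kernel fuses this on-chip):

```python
import math

import torch


def online_softmax_attend(score_chunks, value_chunks):
    """Streaming softmax(scores) @ values in one pass over chunks.

    Maintains a running max `m`, running denominator `z`, and an output
    accumulator; when a chunk raises the max, previous partial sums are
    rescaled by exp(m_old - m_new). This is the numerically stable
    online-softmax update used by tiled attention kernels.
    """
    m = float("-inf")
    z = 0.0
    acc = torch.zeros_like(value_chunks[0][0])
    for s, v in zip(score_chunks, value_chunks):
        m_new = max(m, s.max().item())
        scale = math.exp(m - m_new) if m != float("-inf") else 0.0
        p = torch.exp(s - m_new)       # (chunk_len,)
        z = z * scale + p.sum().item() # rescale old denominator, add new mass
        acc = acc * scale + p @ v      # rescale old output, add chunk output
        m = m_new
    return acc / z
```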
HuggingFace direct patch path:
- `mka/hf/attention.py`: `HFFastMKAAttention` wrapper
- `mka/hf/patch.py`: monkey patch over decoder `self_attn`
- `scripts/train_hf_patch.py`: train loop with patched attention
- `configs/hf_qwen_fastmka.yaml`, `configs/hf_llama_fastmka.yaml`
- `scripts/launch_dp_torchrun.sh`: DP launch (torchrun)
- `scripts/launch_tp_dp_accelerate.sh`: TP+DP launch path (accelerate + HF TP)
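The monkey patch amounts to swapping each decoder layer's `self_attn` module; a hedged sketch (the real logic in `mka/hf/patch.py` will differ in detail):

```python
import torch.nn as nn


def patch_self_attn(model: nn.Module, make_attn):
    """Replace every submodule's `self_attn` with a wrapper built by `make_attn`.

    `make_attn(old_attn)` should return an nn.Module with the same forward
    signature as the original attention (HF Qwen/Llama style decoder layers).
    Returns the number of layers patched.
    """
    # Collect targets first so we don't mutate the tree mid-traversal.
    targets = [m for m in model.modules() if hasattr(m, "self_attn")]
    for module in targets:
        module.self_attn = make_attn(module.self_attn)
    return len(targets)
```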
To build the CUDA extension:

```bash
cd mka/cuda
python build.py build_ext --inplace
cd ../..
```

If the build fails with `CUDA_HOME environment variable is not set`, export your CUDA path first, e.g. `export CUDA_HOME=/usr/local/cuda`.
2.1 Single GPU

```bash
python scripts/train_hf_patch.py --config configs/hf_qwen_fastmka.yaml
```

2.2 Multi-GPU DP

```bash
bash scripts/launch_dp_torchrun.sh configs/hf_qwen_fastmka.yaml 4
```

2.3 Multi-GPU TP+DP
- Set `tp_size` in config (>1).
- Launch: `bash scripts/launch_tp_dp_accelerate.sh configs/hf_qwen_fastmka.yaml 4`

Notes:
- TP relies on HuggingFace native `tp_plan="auto"` support for the model/version.
- Dependencies use lower bounds in `requirements.txt` (adjust `torch` for your CUDA wheel).
- Training throughput (`train_throughput_tok_s`) is measured after `warmup_steps` (see YAML) and includes forward + backward + optimizer. Inference prefill/decode (forward-only) is reported by `scripts/bench_inference_metrics.py`.
- The FastMKA CUDA kernel is used automatically when:
  - the `fastmka_cuda` extension is available,
  - the tensor is on CUDA,
  - `head_dim <= 256`,
  - no extra additive attention mask is required.
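That dispatch decision can be sketched as a single predicate (illustrative; the real check lives inside the layer's forward):

```python
import torch

try:
    import fastmka_cuda  # extension built by mka/cuda/build.py
    HAS_FASTMKA_CUDA = True
except ImportError:
    HAS_FASTMKA_CUDA = False


def can_use_fastmka_kernel(q: torch.Tensor, attn_mask) -> bool:
    """Mirror the conditions under which the fused CUDA kernel is dispatched."""
    return (
        HAS_FASTMKA_CUDA       # extension importable
        and q.is_cuda          # tensor lives on a CUDA device
        and q.size(-1) <= 256  # kernel's head_dim limit
        and attn_mask is None  # no extra additive mask supported
    )
```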
- YAML: `seed`, `warmup_steps` (exclude cold-start from timed throughput), optional `deterministic: true` (slower, stricter cuDNN).
- `train_hf_patch.py`: logs `train_mean_loss`, `train_total_elapsed_s`, `train_throughput_tok_s` (post-warmup), `peak_gpu_memory_gb`; optional `--eval-ppl` for validation PPL.
- `bench_inference_metrics.py`: `prefill_tok_s`, `decode_tok_s`, per-phase peak GPU memory, `kv_cache_bytes_*` from `past_key_values`. HBM bandwidth is not available from PyTorch alone; use Nsight / `nvidia-smi dmon` on the host.
Block-MKA (§4.2) maps memory to compute tiers: L1 on-chip SRAM (tiled attention, online softmax with running max and partition sum); L2 HBM (activations, Q/K/V, fused KV cache); L3 DRAM (vectorized hash, chunk recall). FastMKA (Algorithm 2) route-fuses L1/L2/(L3) into one hidden representation, then one KV projection and one causal attention — the dominant data path is fused activations → KV on HBM → attention (see details in our paper, Tables 4–6).
The YAML `memory_hierarchy` block records these tiers for reproducibility (`mka/config/memory_hierarchy.py`). Older keys (`hbm_enabled`, `dram_staging`, `ssd_tier_path`) still parse as aliases.
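An illustrative config fragment for such a block (field names follow the tier description above, not necessarily the exact schema in `mka/config/memory_hierarchy.py`):

```yaml
# Illustrative memory_hierarchy fragment — field names are assumptions.
memory_hierarchy:
  l1: sram   # on-chip tiles: tiled attention, online softmax
  l2: hbm    # activations, Q/K/V, fused KV cache
  l3: dram   # vectorized hash, chunk recall
```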
Scripts vs metrics
- Training: `train_hf_patch.py` — tokens/s includes backward + optimizer.
- Inference-style: `bench_inference_metrics.py` — prefill vs decode, KV bytes, per-phase GPU peak memory.
- HBM bandwidth: use Nsight Compute / vendor tools, not PyTorch alone.
References for implementation quality
- FlashAttention-2 / SDPA for L1+L2 fused attention paths.
- Paged KV / disk offload patterns (vLLM, FlashInfer-class) when extending L3 or spill.
- ZeRO-Infinity–style paths for `ssd_spill_path` experiments.
For evaluation on LongBench and RULER, use:
- `scripts/run_longbench.sh` for the LongBench workflow
- `scripts/run_ruler.sh` for the RULER workflow
```bibtex
@inproceedings{mka2026,
  title     = {MKA: Memory-Keyed Attention for Efficient Long-Context Reasoning},
  author    = {Dong Liu and Yanxuan Yu and Ben Lengerich and Ying Nian Wu},
  booktitle = {Proceedings of the ACM International Conference on Computing Frontiers (CF '26)},
  year      = {2026}
}
```