SuperSonic-MoE is a blazing-fast Mixture-of-Experts (MoE) implementation optimized for NVIDIA Hopper and SM100 GPUs, leveraging CuTeDSL and Triton.
The current SM100 frontier is green on branch race-fix-paddle. The active path is:
DeepEP topk metadata
-> route-level padding
-> zero-materialization FP8 up-proj (A_idx, no x_gathered)
-> fused GEMM + SwiGLU + z FP8 epilogue quant
-> FP8 down-proj
-> FP8-C-load GemmDGated backward
-> iso32 dz dual quant
-> TMA reduce-add wgrad into framework main_grad
Latest reference-shape benchmark (T=8192,H=3072,I=1536,E=8,K=8, Target GPU, nsys GPU-projection):
| Metric | Current value | Notes |
|---|---|---|
| FP8 busy time | 2659.8 µs/iter | reports/fresh_benchmark_ws1/ |
| MFU | 46.51% | denominator: 4500 TFLOPS FP8 peak |
| Peak measured MFU | 51.61% | T8192-H4096-I4096-E8-K8 |
| Speedup vs true BF16 (same codebase) | 1.63x | 2659.8 vs 4346 µs; BF16 verified zero FP8 kernels |
| Speedup vs historical S53 cuBLAS BF16 | ~1.37x | 2659.8 vs 3644 µs; S53 3644 µs is PyTorch-native, no Paddle proxy |
| Precision | cos >= 0.997, RRMSE < 7.6% | output/dx/ds/dw1/dw2 suites |
| Determinism | bit-exact repeated fwd/bwd | tests/fp8_frontier_determinism_test.py hard gate |
Core MFU formula:
F = 18 * TK * H * I
MFU = F / (busy_seconds * peak_FLOPs_per_second)
Reference:
TK = 8192 * 8 = 65536
F = 18 * 65536 * 3072 * 1536 = 5.566e12 FLOPs
ideal @4500 TFLOPS = 1237 µs
measured = 2659.8 µs
MFU = 46.51%
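The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch: the helper names `moe_flops` and `mfu` are illustrative and not part of the codebase, and the 4500 TFLOPS peak is simply the denominator quoted in the table.

```python
def moe_flops(tokens: int, top_k: int, hidden: int, intermediate: int) -> float:
    """Fwd+bwd FLOPs per iteration: F = 18 * TK * H * I, with TK = tokens * top_k."""
    return 18.0 * tokens * top_k * hidden * intermediate


def mfu(busy_us: float, flops: float, peak_tflops: float = 4500.0) -> float:
    """MFU = F / (busy_seconds * peak_FLOPs_per_second)."""
    return flops / (busy_us * 1e-6 * peak_tflops * 1e12)


# Reference shape: T=8192, K=8, H=3072, I=1536, measured 2659.8 us/iter.
flops = moe_flops(tokens=8192, top_k=8, hidden=3072, intermediate=1536)
ideal_us = flops / (4500.0 * 1e12) * 1e6  # time at 100% of the quoted peak
print(f"F = {flops:.3e} FLOPs, ideal = {ideal_us:.0f} us, "
      f"MFU = {mfu(2659.8, flops):.2%}")
```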
| Parameter | Value | Derivation |
|---|---|---|
| α | 7.491 × 10⁻⁹ | = 18/(P · η_tc), η_tc = 53.4% |
| β | 9.520 × 10⁻⁷ | L2 miss penalty per (access × refetch-width) |
| TK₀ | 39 747 | ≈ (L2/E) / (tile_M · tile_N · b_elem) · tile_M = (96MB/8) / 64KB · 256 ≈ 49K (fitted: 40K) |
| γ | 21.72 µs | Per-expert TMA descriptor + metadata |
Fit: MAPE = 3.79%, MFU MAE = 1.49%, N = 104 (uniform 4E × 4HI × 6TK grid + 16 large-TK).
| # | Term | Physical origin | Why this functional form |
|---|---|---|---|
| 1 | α · TK · H · I | 6 CUTLASS GEMMs totalling 18 · TK · H · I FLOPs. η_tc < 1 due to persistent-scheduler per-tile overhead and DGated epilogue register pressure (168 regs → 1 block/SM). | Proportional to F/P at constant η_tc. Shape-invariant: same tile config (256², cluster 2×1) across all (H, I). |
| 2 | β · TK · D · ln(1 + TK/TK₀) | Weight-tile L2 reuse breakdown. Scheduler cycles TK/tile_M tiles; when this exceeds L2 capacity per expert (~192 sets), multi-stream SM contention drives miss rate proportional to ln(TK/TK₀). Each miss refetches D bytes. | D = max(H, 2I): widest tile dimension among the 6 GEMMs' B-tensors. TK: total L2 accesses ∝ token count. ln(1+x): smooth onset; empirical consensus for shared-cache thrashing under randomized multi-tenant load. |
| 3 | γ · E | Per-expert one-time: TMA descriptor, weight-cache version check, routing histogram. O(1) per expert, independent of TK, H, I. | E=8: 174 µs (< 7%); E=128: 2784 µs (dominant only when TK < 50K). |
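The three terms can be evaluated directly. The sketch below assumes they simply add up to the busy time in microseconds, which the table implies but does not state as one formula; the function names are illustrative.

```python
import math

# Fitted constants from the parameter table (times in microseconds).
ALPHA = 7.491e-9      # us per (TK * H * I)
BETA = 9.520e-7       # us per (TK * D), scaled by the log factor
TK0 = 39_747          # onset of L2 reuse breakdown
GAMMA = 21.72         # us per expert
PEAK_FLOPS = 4500e12  # FP8 peak used as the MFU denominator


def predicted_busy_us(tk: int, h: int, i: int, e: int) -> float:
    """Assumed additive model: GEMM term + L2-miss term + per-expert term."""
    d = max(h, 2 * i)                             # widest B-tensor tile dimension
    gemm = ALPHA * tk * h * i
    l2 = BETA * tk * d * math.log1p(tk / TK0)     # ln(1 + TK/TK0)
    fixed = GAMMA * e
    return gemm + l2 + fixed


def predicted_mfu(tk: int, h: int, i: int, e: int) -> float:
    flops = 18.0 * tk * h * i
    return flops / (predicted_busy_us(tk, h, i, e) * 1e-6 * PEAK_FLOPS)


# Reference shape at three token counts: rise, near-peak, decay.
for tk in (4096, 65536, 4_194_304):
    print(tk, f"{predicted_mfu(tk, 3072, 1536, 8):.1%}")
```

At the reference shape (TK=65536) this gives roughly 46%, in line with the measured 46.51%.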
| Phase | Regime | Limiting behavior |
|---|---|---|
| Rise | TK ≪ TK₀ | MFU ≈ 18 · H · I · TK / (P · γ · E), linear ↑ |
| Peak | TK ~ TK₀ | Maximum, ∂MFU/∂TK = 0 |
| Decay | TK ≫ TK₀ | MFU ∝ 1/ln(TK), logarithmic ↓ |
| Shape | Measured peak | Model@peak | Measured @4M | Model@4M |
|---|---|---|---|---|
| Reference | 45.7% @65K | 46.3% | 38.3% | 38.3% |
| H=4096 I=2048 | 50.5% @65K | 48.5% | 40.6% | 41.0% |
| H=4096 I=4096 | 52.4% @33K | 49.1% | 43.6% | 41.0% |
| H=6144 I=3072 | 52.9% @33K | 50.5% | 44.3% | 44.6% |
Left column: MFU vs TK with model curves (solid) and nsys-measured points (markers). Dashed line = steady-state.
Right column: MFU contour in (log₁₀TK, E) space; red curve = TK* (peak MFU trajectory).
Read first:
| Priority | Document | Purpose |
|---|---|---|
| 1 | HANDOFF.md | Canonical current project state, lessons, next plan |
| 2 | reports/sonic_moe_fp8_frontier_newcomer_guide.md | Standalone FP8/CuTe/SuperSonic-MoE newcomer guide + expert Q&A |
| 3 | docs/gemm_dgated_fp8cload.md | GemmDGatedFP8CLoad design: Int16 trick, dSwiGLU, FP8 dequant, zero-materialization |
| 4 | docs/expert_interleave_weight_layout.md | Expert Interleave weight layout: 5-layer benefit analysis + originality |
| 5 | reports/sonic_moe_comprehensive_analysis.md | Broad technical analysis and roofline/precision/perf tables |
| 6 | reports/fresh_benchmark_ws1/README.md | Latest sweep data and MFU fit |
| 7 | reports/reference_shape_ncu/README.md | NCU resource breakdown for the 6 GEMMs (reference shape) |
Important current insights:
- `GemmDGatedFP8CLoadSm100ZeroMat` is the main structural bottleneck: 168 regs/thread, ~42% tensor-pipe. Do not directly add dz quant loops to this epilogue.
- TMA reduce-add is a register/performance optimization, not higher-precision accumulation.
- iso32 dz dual quant is measured safe for current reference-shape dz; monitor `log2(block_amax/row_amax)` before generalizing.
- `SonicMoEMlpNode.step()` must run before `optimizer.step()` because it flushes native CUTLASS wgrad layout into framework `main_grad`.
ISO32 (32×32 block) FP8 weight quantization stores one buffer per weight instead of two transposed copies by exploiting the byte-identical transpose invariant of isotropic block scaling. Forward and backward GEMM kernels consume the same physical buffer via zero-copy stride views. Controlled by SONIC_MOE_FP8_ISO32_WEIGHT=1 (default OFF).
Memory saving: 48.5% of weight FP8 cache (108 MiB at E=8, H=3072, I=1536).
Precision: Identical to baseline 1×32 path vs BF16 golden (ratio=1.0000, verified across 9 shapes).
Performance: −0.5% GPU-projection (1 fewer unique kernel, identical GEMM signatures).
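The transpose invariant can be checked on a toy example. The sketch below is not the production quantizer: it uses a block size of 2 in place of 32 and small integer codes in place of FP8 bytes, and `quantize_blockwise` is a hypothetical helper. It only shows that with one scale per square block, quantizing W and then transposing gives the same codes as quantizing Wᵀ.

```python
BLOCK = 2  # stand-in for 32 so the demo stays tiny


def quantize_blockwise(w, block=BLOCK, qmax=7):
    """One scale per (block x block) tile; integer codes stand in for FP8 bytes."""
    rows, cols = len(w), len(w[0])
    codes = [[0] * cols for _ in range(rows)]
    scales = {}
    for br in range(0, rows, block):
        for bc in range(0, cols, block):
            tile = [w[r][c] for r in range(br, br + block)
                    for c in range(bc, bc + block)]
            amax = max(abs(v) for v in tile) or 1.0   # avoid a zero scale
            scale = amax / qmax
            scales[(br // block, bc // block)] = scale
            for r in range(br, br + block):
                for c in range(bc, bc + block):
                    codes[r][c] = round(w[r][c] / scale)
    return codes, scales


def transpose(m):
    return [list(col) for col in zip(*m)]


w = [
    [0.5, -1.0, 2.0, 0.25, 4.0, -0.5],
    [1.5, 0.75, -3.0, 1.0, 0.1, 2.5],
    [-2.0, 0.3, 0.6, -0.9, 1.2, 0.0],
    [0.8, 1.1, -0.4, 3.5, -2.2, 0.7],
]
codes, scales = quantize_blockwise(w)
codes_t, scales_t = quantize_blockwise(transpose(w))

# Quantize-then-transpose equals transpose-then-quantize, block scales included.
assert codes_t == transpose(codes)
assert all(scales_t[(j, i)] == s for (i, j), s in scales.items())
```

Because the two results match element for element, a single stored buffer plus a transposed stride view can serve both GEMM directions, which is the property the ISO32 path relies on.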
Left: Baseline pair-kernel path (4 separate FP8 weight buffers). Right: ISO32 single-buffer path (2 buffers + stride views). Dashed boxes = eliminated allocations.
Full DeepEP pipeline: Router → All-to-All dispatch → FP8 FFN (with persistent weight cache) → Combine → reverse A2A. Three gradient paths: dx (autograd), ds (router gate), dw (main_grad accumulate).
Validation report: reports/iso32_weight_validation_report.md
- NVIDIA Hopper GPUs (H100, H200) or SM100 GPUs (Target GPU, Target GPU, Target GPU)
- CUDA 12.9+ (13.0+ for Target GPU)
- Python 3.12+, PyTorch 2.7+
- `USE_QUACK_GEMM=1` for SM100 kernels
The default code path is the fused-v2 frontier (FP8 epilogue blockscaled quant + fused-gated up-proj + TMA reduce-add wgrad + CUDA fused topk-metadata kernel). Most flags below are already set to their production-recommended defaults; they are documented here so other agents can quickly reason about the active path.
The fused CUDA topk→metadata kernel (sonicmoe.ernie_compat.deepep_topk_metadata_cuda) is JIT-compiled on first import and dispatched automatically by deepep_topk_to_sonic_metadata. There is no env flag to gate it: when the import succeeds it is always used; when the build fails the code falls back to the Triton-only path. To force the legacy path for A/B comparison, delete the build directory before import:
rm -rf $TORCH_EXTENSIONS_DIR/sonicmoe_deepep_topk_metadata_cuda  # (or matching paddle build dir)

Active production flags (set these for training):
| Env flag | Default | What it gates |
|---|---|---|
| `SONIC_MOE_FP8_MODE` | unset | Master switch. perf = full FP8 frontier; mem = FP8 + stage-wise memory reuse; unset = BF16. |
| `USE_QUACK_GEMM` | unset | Required on SM100 for CuTeDSL FP8 GEMMs. |
| `SONIC_MOE_FP8_WGRAD` | unset | FP8 weight gradients (otherwise BF16 wgrad). Always 1 for production. |
| `SONIC_MOE_FP8_ASSUME_ALIGNED` | 0 | Skip runtime padding-check H2D sync. Set 1 for zero-sync training (requires aligned token counts). |
| `SONIC_MOE_FP8_ISO32_WEIGHT` | 0 | ISO32 weight cache unification: stores ONE fp8 buffer per weight (saves 48.5% weight cache memory). |
| `SONIC_MOE_FP8_RECOMPUTE_Z` | 0 | Skip storing z_fp8 in fwd; rerun up-proj in bwd. Saves ~213 MiB/layer, costs ~5–15% extra time. |
| `SONIC_MOE_STAGEWISE_MEMORY` | 0 | Free activations eagerly between stages. Saves ~1.0–1.5 GB at reference shape, costs 3–5%. |
| `SONIC_MOE_CACHE_DIR` | ~/.cache/sonicmoe | JIT compile-cache directory. |
Hardcoded defaults (always-on in the frontier path, no need to set):
| Flag | Value | Notes |
|---|---|---|
| `SONIC_MOE_FP8_FUSED_GATED` | 1 | Fused SwiGLU+quant in up-proj epilogue. Non-fused path is deprecated. |
| `SONIC_MOE_FP8_EPILOGUE_QUANT` | 1 | FP8 quant inside GEMM epilogue. Always active when FP8 enabled. |
| `SONIC_MOE_FP8_FUSED_SWIGLU_QUANT` | 1 | Fused SwiGLU+rowquant on y1 path. Always active. |
| `SONIC_MOE_FP8_SAVE_Z_FP8` | 1 | Save z as FP8 for backward (vs BF16). Always active. |
Deprecated / do-not-use:
| Flag | Reason |
|---|---|
| `SONIC_MOE_FP8_RECOMPUTE_OPT_B` | Produces illegal-instruction on non-uniform routing. Will be removed. |
| `SONIC_MOE_FP8_WGRAD_BETA_ACCUM` | Legacy fused-beta epilogue (86 regs). TMA reduce-add (50 regs) is strictly better. |
| `SONIC_MOE_FP8_UPPROJ_EPILOGUE_PRECISION` | Always fp8-blockscaled. No other value tested or supported. |
| `SONIC_MOE_FP8_DOWNPROJ_MAINLOOP_PRECISION` | Always fp8-blockscaled. |
| `SONIC_MOE_FP8_DOWNPROJ_WEIGHT_PRECISION` | Always bf16. FP8 weight not supported in down-proj mainloop. |
| `SONIC_MOE_FP8_FUSED_ZY1_QUANT` | Experimental dual z+y1 quant kernel. Slower than separated path on production shapes. |
| `SONIC_MOE_FP8_BLOCKSCALED_EXPERT_CAPACITY` | Benchmark-only override. Never set in production. |
Production launcher:
USE_QUACK_GEMM=1 \
SONIC_MOE_FP8_MODE=perf \
SONIC_MOE_FP8_WGRAD=1 \
SONIC_MOE_FP8_ASSUME_ALIGNED=1 \
TRITON_PTXAS_PATH=/usr/local/cuda-13.0/bin/ptxas \
python train.py

The race-fix-paddle branch integrates SonicMoE into Paddle via `paddle.compat.enable_torch_proxy`. Production entry point: `SonicMoEMlpNode`.
from sonicmoe.ernie_compat import SonicMoEMlpNode
node = SonicMoEMlpNode(experts, n_experts=E, hidden_size=H, intermediate_size=I)
# ── Cold start: first fwd+bwd triggers JIT compilation (~42s) ──
# For explicit warmup before training:
# from sonicmoe.jit_warmup import warmup_jit
# warmup_jit(E=8, H=3072, I=1536, device="cuda")
for step in range(num_steps):
    for mb in microbatches:
        out = node(x, tokens_per_expert, dispatched_indices, dispatched_probs)
        out.backward(grad)
    node.step()   # MUST run BEFORE optimizer.step():
                  # converts native CUTLASS [E,2I,H] → framework split-half
                  # [E,H,2I] in-place into expert.weight.main_grad,
                  # which the optimizer then reads.
    optimizer.step()
    optimizer.clear_grad()

`main_grad` is lazy-allocated on first backward (S74 follow-up); inference / warmup-only flows pay zero `main_grad` memory. Weight cache invalidation is automatic via `(data_ptr, _inplace_version(w))` keys, so no manual `clear_*` call is needed after in-place optimizer updates.
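The version-keyed invalidation can be illustrated with a toy cache. This is a sketch only: `FakeWeight` is a hypothetical stand-in for a framework tensor that bumps a version counter on in-place updates, and `cached_quantize` is not the library's function.

```python
class FakeWeight:
    """Hypothetical tensor stand-in with an in-place version counter."""

    def __init__(self, values):
        self.values = list(values)
        self.version = 0

    def data_ptr(self):
        return id(self.values)      # stable while the storage is reused

    def add_(self, delta):
        for k in range(len(self.values)):
            self.values[k] += delta
        self.version += 1           # every in-place write bumps the version


_cache = {}
quantize_calls = 0


def cached_quantize(w):
    """Recompute only when (data_ptr, version) has not been seen before."""
    global quantize_calls
    key = (w.data_ptr(), w.version)
    if key not in _cache:
        quantize_calls += 1
        _cache[key] = [round(v * 16) for v in w.values]  # placeholder "quantization"
    return _cache[key]


w = FakeWeight([0.5, 1.0])
cached_quantize(w)
cached_quantize(w)      # cache hit: same pointer, same version
w.add_(0.25)            # simulated in-place optimizer update
cached_quantize(w)      # cache miss: version changed, so the entry is rebuilt
print(quantize_calls)
```

A real cache would also bound its size and evict stale versions; the sketch omits that.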
| Phase | What happens | Time |
|---|---|---|
| Import | CUDA topk metadata kernel compiled | ~4s |
| 1st fwd+bwd | CuTe CUTLASS GEMM + all Triton kernels JIT compiled | ~42s |
| 2nd fwd+bwd | All caches hit, steady-state | 0.05s |
| New seqlen | CuTe GEMM: 0 recompile (dynamic dims via mark_layout_dynamic). Triton: ~2.5s per new TK (cached in ~/.triton/cache across sessions) | 0-2.5s |
| Before optimizer.step() | Call node.step() → flushes native-layout wgrad to per-expert main_grad, clears weight/FP8/topk caches | <1ms |
Four caches, each with a different invalidation strategy:
| Cache | Key includes | Invalidated by | Max size |
|---|---|---|---|
| CuTe compile cache (`_COMPILE_CACHE*`) | Static model dims (H, I, E, dtype, tile config) only. No token counts. | Never (persistent for model lifetime) | Unbounded (typically 3-8 entries) |
| Fast-path runtime cache (`_GEMM_FAST_PATH*`) | Exact problem shape (total_M/K + all tensor dims) | Auto-eviction at 64 entries | 64 |
| FP8 weight cache (`_WEIGHT_CACHE` etc.) | data_ptr + _inplace_version + shape + stride | node.step() / invalidate_weight_caches() | 8 per cache |
| Triton JIT cache (`~/.triton/cache/`) | Full kernel source hash | rm -rf ~/.triton/cache | Unbounded (disk) |
Design principle: compile_key contains only static model dimensions — never TK, total_M, total_K, capacity, or any token-count-dependent value. Dynamic token dimensions are handled at runtime via CuTe's mark_layout_dynamic. This ensures zero CuTe recompilation when batch size or routing distribution changes.
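A minimal sketch of this design principle follows. `CompileKey` and `get_kernel` are hypothetical names, the "kernel" is a plain closure, and the dtype string is an illustrative assumption. The point is only that the key holds static dimensions, so changing the token count never triggers a second compile.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CompileKey:
    """Static model dimensions only: no TK, total_M, total_K, or capacity."""
    hidden: int
    intermediate: int
    num_experts: int
    dtype: str
    tile: tuple


_compiled = {}
compile_count = 0


def get_kernel(key: CompileKey):
    """Return the cached 'kernel' for this key, compiling it on first use."""
    global compile_count
    if key not in _compiled:
        compile_count += 1  # stands in for an expensive JIT compile
        # The token count is a runtime argument, not part of the key.
        _compiled[key] = lambda total_m: (key.hidden, key.intermediate, total_m)
    return _compiled[key]


key = CompileKey(hidden=3072, intermediate=1536, num_experts=8,
                 dtype="fp8_e4m3", tile=(256, 256))  # dtype name is illustrative
for total_m in (8192, 65536, 131072):  # batch size / routing changes
    get_kernel(key)(total_m)

print(compile_count)
```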
Production deploys 1 process per GPU; every rank on every node sees the
same build/ directory and SONIC_MOE_CACHE_DIR over GPFS. The JIT layer
is fully race-safe under concurrent cold start across ranks:
| Layer | Race-safety mechanism |
|---|---|
| C++ extensions (`sonicmoe/jit.py`) | Cross-process FileLock at `<parent>/.{module_name}.lock` (stable inode survives `build/<module>/` wipe). After lock acquisition the worker first tries `_try_import_prebuilt`: it directly dlopen-imports `<build_dir>/<name>/<name>.so` (PYBIND11), bypassing paddle's racy `load()` whenever the artifact already exists. Only the first rank actually compiles; later ranks fast-path import. |
| Triton kernels (`~/.triton/cache`, or `$TRITON_CACHE_DIR`) | Per-key subdirs; same-key concurrent writes use Triton's atomic temp+rename. |
| Quack / CuTe autotuner (`autotuner.py:check_disk_cache`) | sha256(VERSION + tuning_key + configs) → per-key file. Different shapes write to different files; same-key writes are atomic. Once any rank has tuned a (dtype, K, config) combination, all other ranks read the same JSON on the next iteration. |
| `paddle.enable_compat()` ordering | `_get_cpp_function` re-arms the `torch.utils.hipify` proxy blocker (`install_quack_paddle_compat()`) inside the FileLock immediately before `cpp_extension.load()`, so the patches are live regardless of whether the consumer imported sonicmoe before or after `paddle.enable_compat()`. |
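The lock-then-fast-path pattern in the first row can be sketched with the standard library alone. This is an illustration, not the code in `sonicmoe/jit.py`: it substitutes `fcntl.flock` (Unix-only) for `FileLock`, uses a generated `.py` file in place of a compiled `.so`, and `load_or_build` is a hypothetical name. Whether `flock` is reliable on a given shared filesystem is outside the scope of the sketch.

```python
import fcntl
import importlib.util
import os
import tempfile


def load_or_build(parent_dir, name, build_fn):
    """Import <parent>/<name>/<name>.py, building it under a lock if absent."""
    lock_path = os.path.join(parent_dir, f".{name}.lock")  # outside the build dir
    artifact = os.path.join(parent_dir, name, f"{name}.py")
    with open(lock_path, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)   # blocks until other holders release
        try:
            if not os.path.exists(artifact):    # only the first holder builds
                os.makedirs(os.path.dirname(artifact), exist_ok=True)
                build_fn(artifact)
            # Fast path: import the artifact directly from its file location.
            spec = importlib.util.spec_from_file_location(name, artifact)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return module
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)


builds = []


def fake_build(path):
    """Stand-in for a slow compile step."""
    builds.append(path)
    with open(path, "w") as f:
        f.write("ANSWER = 42\n")


with tempfile.TemporaryDirectory() as root:
    first = load_or_build(root, "demo_ext", fake_build)
    second = load_or_build(root, "demo_ext", fake_build)  # artifact exists: no rebuild

print(len(builds), first.ANSWER, second.ANSWER)
```

Keeping the lock file in the parent directory means it survives a wipe of the per-module build directory.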
Heterogeneous-shape concurrent cold starts (e.g. GPU0 hits total_K=4096,
GPU1 hits total_K=8192 simultaneously on a fresh cache) are explicitly
covered by tests/ops/test_jit_concurrent_heterogeneous.py and gated in CI
under phase jit-concurrent.
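The per-key, atomically written cache file that makes heterogeneous cold starts safe can be sketched as follows. The helper names, the `VERSION` string, and the JSON payloads are illustrative assumptions, not the Quack autotuner's actual format.

```python
import hashlib
import json
import os
import tempfile

VERSION = "demo-1"  # hypothetical cache-format version


def cache_path(cache_dir, tuning_key, configs):
    """sha256 over (VERSION, tuning_key, configs) selects one file per key."""
    payload = json.dumps([VERSION, tuning_key, configs], sort_keys=True)
    digest = hashlib.sha256(payload.encode()).hexdigest()
    return os.path.join(cache_dir, f"{digest}.json")


def atomic_write_json(path, obj):
    """Write to a temp file in the same directory, then rename over the target."""
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path), suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(obj, f)
        os.replace(tmp, path)   # readers see the old file or the new one, never a partial
    except BaseException:
        if os.path.exists(tmp):
            os.unlink(tmp)
        raise


with tempfile.TemporaryDirectory() as cache_dir:
    # Two ranks hitting different shapes write to different files.
    path_a = cache_path(cache_dir, "total_K=4096", [{"tile": 256}])
    path_b = cache_path(cache_dir, "total_K=8192", [{"tile": 256}])
    atomic_write_json(path_a, {"best": 0})
    atomic_write_json(path_b, {"best": 1})
    with open(path_a) as f:
        loaded_a = json.load(f)
    leftovers = [n for n in os.listdir(cache_dir) if n.endswith(".tmp")]

print(path_a != path_b, loaded_a, leftovers)
```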
Recommended deployment: pre-warm once per cluster, then point every rank at the shared cache:
python -m sonicmoe.cli.warmup --E 32 --H 3072 --I 1536 \
--cache-dir /gpfs/.../sonicmoe_jit_e32_h3072_i1536
# every rank
export SONIC_MOE_CACHE_DIR=/gpfs/.../sonicmoe_jit_e32_h3072_i1536

Skipping pre-warm is also safe: ranks will JIT-compile in parallel under the locks above; the cluster pays the cold-start cost once and converges to steady-state on iteration ≥ 2.
- Call `paddle.enable_compat()` (or `paddle.compat.enable_torch_proxy(scope={"sonicmoe", "quack", "triton"})`) before `import sonicmoe`. The reverse order is also tolerated (`_get_cpp_function` re-arms patches lazily), but the recommended order avoids a duplicate proxy install on the hot path.
- Set `SONIC_MOE_CACHE_DIR` to a shared GPFS path so all ranks reuse Triton/Quack disk caches; otherwise each rank caches under `~/.cache/sonicmoe`.
- Construct one `SonicMoEMlpNode` per fused expert group (see Paddle Integration above). Call `node.step()` BEFORE `optimizer.step()` each microbatch; this flushes native-layout wgrad into per-expert `main_grad`.
- For diagnostics: `SONIC_MOE_VALIDATE_DISPATCH=1` enables an O(T) per-row uniqueness check on `dispatched_indices` before kernel launch (catches malformed routing in custom dispatchers).
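The per-row uniqueness contract in the last bullet can be expressed as a small host-side check. This is a sketch on plain Python lists, not the validator in `deepep_metadata.py`; `validate_dispatch` is a hypothetical name, and treating negative ids as unused top-k slots is an assumption.

```python
def validate_dispatch(dispatched_indices, num_experts):
    """Raise ValueError if any row repeats an expert id or has an id out of range."""
    for row_id, row in enumerate(dispatched_indices):
        seen = set()
        for expert in row:
            if expert < 0:
                continue  # assumption: negative ids mark unused top-k slots
            if expert >= num_experts:
                raise ValueError(f"row {row_id}: expert id {expert} out of range")
            if expert in seen:
                raise ValueError(f"row {row_id}: duplicate expert id {expert}")
            seen.add(expert)


# A well-formed routing table passes silently.
validate_dispatch([[0, 3], [1, 2], [2, -1]], num_experts=4)

# A row that selects the same expert twice is rejected.
try:
    validate_dispatch([[0, 3], [1, 1]], num_experts=4)
    rejected = False
except ValueError as err:
    rejected = True
    message = str(err)
print(rejected)
```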
- `paddle.distributed.launch` on paddlejob: WHITELIST env, never denylist. The paddlejob cluster exports `PADDLE_TRAINERS`=4 IPs, `DISTRIBUTED_TRAINER_ENDPOINTS` (32 entries), `POD_*`, `EKS_POD_*`, `PADDLE_CURRENT_ENDPOINT`, `PADDLE_CLUSTER_TRAIN`, `PADDLE_IS_LOCAL=0`, etc. Any one of these makes the launcher silently enter multi-NODE rendezvous mode and block forever (no error, no children, launcher in `R` state). `tools/ci/multicard_smoke.py` builds env from a strict prefix whitelist (`PATH/LD_/HOME/USER/LANG/CUDA_/NVIDIA_/TRITON_/SONIC_MOE_/FLAGS_/NCCL_/GLOG_/OMP_/PYTHON…`).
- `Place(gpu:N) is not supported` is almost always a lazy device-pool init bug. The pool only registers the place named by `FLAGS_selected_gpus`; any other place errors. Eager-allocate a 1-element tensor right after `paddle.device.set_device(...)` to force pool registration before any async path (autograd backward, `paddle.library` proxies inside quack JIT, `paddle.tensor.random.gaussian`) hits it. Same root cause as the production crash from `quack.autotuner._gpu_warmup`.
- `_FP8Config()` snapshots `is_fp8_active()` at construction, not at use. Always construct it INSIDE the `with enable_fp8(True):` block; otherwise a prior test/code path's `enable_fp8(False)` lingering state leaks and the wgrad path silently falls back to BF16. Belt-and-braces: set `SONIC_MOE_FP8_MODE=perf` BEFORE `import sonicmoe`.
- Multi-process JIT cache on shared GPFS: `sonicmoe.jit` uses `FileLock` on a stable parent directory (NOT inside `build_directory`, which paddle wipes mid-cycle). Triton + Quack disk caches naturally support multi-rank reuse; sentinel (`{cache_root}/warmup_sentinel.json`) gates cold compile. Cross-rank shape divergence (rank 0 sees shape A, rank 1 sees shape B) is supported because each rank takes the file lock per-key.
- quack import path: `/usr/local/bin/python` (the default `python` on this host) lacks the `quack` site-package; it lives at `/root/paddlejob/share-storage/gpfs/system-public/zhangyichen/sonicmoe_deps/quack`. `tests/conftest.py` injects this into `sys.path` automatically; `tools/ci/jit_bench.py::_run_subprocess` and `tools/ci/multicard_smoke.WORKER_BODY` do the same for spawned workers.
- Hipify proxy install order matters. paddle's torch-proxy intercepts `torch.utils.hipify.hipify_python` lookups and raises `KeyError`. Fix: pre-import the real `torch.utils.hipify[.hipify_python]` and add it to `paddle.compat.extend_torch_proxy_blocked_modules`. Re-arm inside the JIT lock so the patch applies regardless of `import sonicmoe` vs `paddle.enable_compat()` ordering.
- SM100 ptxas mismatch. Bundled Triton 3.5 ptxas does not recognize SM100. Tests/CI globally export `TRITON_PTXAS_PATH=/usr/local/cuda-13.0/bin/ptxas` (handled in `tests/conftest.py::_ensure_sm100_ptxas`). Production deployment should set this if running on Target GPU.
- `dispatched_indices` per-row uniqueness contract. The kernel (`sonicmoe/ernie_compat/deepep_topk_metadata_cuda/kernel.cu` L315-333) assumes each row of `dispatched_indices` contains pairwise-distinct expert ids. Real DeepEP dispatch always satisfies this; custom test fixtures must. The optional contract validator at `deepep_metadata.py` (gated by `SONIC_MOE_VALIDATE_DISPATCH=1`) catches violations early.
- paddle's `cpp_extension.load()` writes a wrapper `.py` whose `__bootstrap__` references `*_pd_.so`. In this paddle build the file is plain `*.so` and the wrapper is non-functional; importing the `.so` directly via `spec_from_file_location` (PYBIND11 handles symbol export) is the correct fast-path.
- Stable lock paths matter. paddle's `load()` may wipe its `build_directory` mid-cycle. The FileLock must live in the parent dir so the lock inode survives.
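The whitelist approach from the first pitfall can be sketched in a few lines. This is not the code in `tools/ci/multicard_smoke.py`: `whitelisted_env` is a hypothetical helper, and the prefix tuple copies only the prefixes visible in the list above, which is elided in the source and therefore incomplete.

```python
import os

# Prefixes copied from the list above; the source list is truncated, so this
# tuple is known to be incomplete.
ALLOWED_PREFIXES = (
    "PATH", "LD_", "HOME", "USER", "LANG", "CUDA_", "NVIDIA_",
    "TRITON_", "SONIC_MOE_", "FLAGS_", "NCCL_", "GLOG_", "OMP_", "PYTHON",
)


def whitelisted_env(source=None):
    """Keep only variables whose name starts with an allowed prefix."""
    source = os.environ if source is None else source
    return {k: v for k, v in source.items() if k.startswith(ALLOWED_PREFIXES)}


cluster_env = {
    "PATH": "/usr/bin",
    "CUDA_VISIBLE_DEVICES": "0,1",
    "SONIC_MOE_FP8_MODE": "perf",
    "PADDLE_TRAINERS": "10.0.0.1,10.0.0.2",  # would trigger multi-node rendezvous
    "POD_IP": "10.0.0.1",
    "PADDLE_IS_LOCAL": "0",
}
clean = whitelisted_env(cluster_env)
print(sorted(clean))
```

The resulting dict can be passed as `env=` to `subprocess.run`. Note that a variable such as `USE_QUACK_GEMM` matches none of the prefixes listed here, so a launcher built this way would need to add it explicitly unless the elided part of the real list covers it.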
| Gradient | Mechanism | Verified |
|---|---|---|
| dx (d/d hidden_states) | Paddle autograd through `_SonicMoEDeepEPFunc.backward` | cos=0.9975 |
| ds (d/d dispatched_probs) | `_GatherRouterScores` PyLayer with custom Triton scatter (no CUB cascade) | cos=0.9971–0.9973 |
| dw1, dw2 | CUTLASS wgrad accumulates directly into the per-instance fused [E, 2I, H] / [E, H, I] native buffer (lazy-allocated on first backward); `node.step()` performs the in-place native→framework split-half layout conversion into `expert.weight.main_grad` before `optimizer.step()` reads it. | cos=0.9975 / 0.9971 |
| N | K | E | I | out | dx | dw1 | dw2 |
|---|---|---|---|---|---|---|---|
| 128 | 4 | 4 | 384 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 128 | 8 | 8 | 384 | 0.9979 | 0.9975 | 0.9975 | 0.9971 |
| 512 | 4 | 8 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 512 | 8 | 8 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 1024 | 8 | 8 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 256 | 8 | 32 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9971 |
ds gradient verified via test_cold_start_e2e.py: cos=0.9972 across all 6 shapes (1024/8192/4096/2048/512/16384 tokens).
All cosine > 0.99, RRMSE < 7.6%. Shapes include E=32 (production), varying topk (4/8), small/large token counts.
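For readers unfamiliar with the two metrics, the sketch below shows one common way to compute them on flattened tensors. The README does not define RRMSE, so the formula used here (L2 norm of the error divided by L2 norm of the reference) is an assumption, as are the function names and sample values.

```python
import math


def cosine_similarity(a, b):
    """cos = <a, b> / (||a|| * ||b||) over flattened values."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rrmse(test, ref):
    """Relative RMSE, assumed here to be ||test - ref||_2 / ||ref||_2."""
    err = math.sqrt(sum((x - y) ** 2 for x, y in zip(test, ref)))
    return err / math.sqrt(sum(y * y for y in ref))


ref = [1.0, -2.0, 3.0, 0.5]         # stand-in for a BF16 golden output
test = [1.02, -1.95, 3.05, 0.48]    # stand-in for the FP8 result
cos = cosine_similarity(test, ref)
rel = rrmse(test, ref)
print(f"cos={cos:.4f} rrmse={rel:.2%}",
      "PASS" if cos >= 0.997 and rel < 0.076 else "FAIL")
```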
Current canonical data lives in reports/fresh_benchmark_ws1/. The table below supersedes
older Session-65-only summaries for frontier-level comparisons.
| Shape | QuACK BF16 (µs) | FP8 frontier (µs) | Speedup vs QuACK BF16 | FP8 MFU |
|---|---|---|---|---|
| T=1024 H=3072 I=1536 E=8 K=8 | 540.0 | 566.0 | 0.95x | 27.32% |
| T=2048 H=3072 I=1536 E=8 K=8 | 847.7 | 870.1 | 0.97x | 35.54% |
| T=4096 H=3072 I=1536 E=8 K=8 | 1533.4 | 1459.1 | 1.05x | 42.39% |
| T=8192 H=3072 I=1536 E=8 K=8 | 2942.5 | 2659.8 | 1.11x | 46.51% |
| T=16384 H=3072 I=1536 E=8 K=8 | 6022.1 | 5224.9 | 1.15x | 47.35% |
| T=8192 H=4096 I=4096 E=8 K=8 | 10894.7 | 8521.7 | 1.28x | 51.61% |
Historical S53 BF16 numbers were cuBLAS/PyTorch-style and slower than the current QuACK BF16 path; relative to that historical baseline, reference-shape FP8 is ~1.37x. Always label which BF16 baseline is used.
TMA reduce-add remains part of the current default path. It replaced legacy
D=A@B+1.0*C wgrad accumulation (86 regs/thread) with TMA store-side ADD
(~50 regs/thread), saving 2-4% end-to-end depending on E. It does not raise
accumulation precision; main_grad remains fp32 and determinism is test-gated.
See HANDOFF.md for full kernel breakdown, memory notes, and next-step priorities.
| File | Purpose |
|---|---|
| sonicmoe/ernie_compat/mlp_node_v2.py | SonicMoEMlpNode: production MlpNode with .forward(), .step(), .warmup() |
| sonicmoe/jit_warmup.py | warmup_jit(E, H, I): pre-compiles all CuTe + Triton kernels |
| sonicmoe/quack_utils/blockscaled_fp8_gemm.py | FP8 GEMM wrappers (CUTLASS + Triton quant), cache key design |
| sonicmoe/quack_utils/swiglu_triton.py | Fused SwiGLU Triton kernels (5 production + 2 legacy bf16 variants) |
| sonicmoe/quack_utils/_validate.py | Low-overhead input validation (dtype/stride/shape, zero GPU sync) |
| sonicmoe/ernie_compat/deepep_metadata.py | DeepEP topk → SonicMoE routing metadata conversion |
| sonicmoe/functional/__init__.py | _UpProjection, _DownProjection autograd Functions |
| Test | What it validates | Run command |
|---|---|---|
| test_cold_start_e2e.py | Cache clear → JIT → 6-shape precision (out/dx/ds/dw1/dw2) | CUDA_VISIBLE_DEVICES=2 python tests/ops/test_cold_start_e2e.py |
| test_mlpnode_correctness_large.py | Topk kernel bug-fix regression: 9 cases incl. SEQ=16K (TK=131072), skew/extreme/holes | CUDA_VISIBLE_DEVICES=7 python tests/ops/test_mlpnode_correctness_large.py |
| test_jit_optimization.py --quick | Correctness (cos>0.99), zero JIT recompile, memory | CUDA_VISIBLE_DEVICES=0 python tests/ops/test_jit_optimization.py --quick |
| test_mlpnode_precision.py | Multi-topk precision audit | CUDA_VISIBLE_DEVICES=0 python tests/ops/test_mlpnode_precision.py |
| bench_mlpnode_mem.py | E=32 fwd+bwd memory benchmark (reference shape) | CUDA_VISIBLE_DEVICES=1 python tests/ops/bench_mlpnode_mem.py |
| bench_wgrad_epilogue.py | A/B wgrad epilogue benchmark (TMA add vs fused beta) | CUDA_VISIBLE_DEVICES=2 python tests/ops/bench_wgrad_epilogue.py |
| bench_mlpnode_topk_nsys.py | nsys GPU-projection benchmark | Wrap with nsys profile --resolve-symbols=false |
| Priority | Resource | Path |
|---|---|---|
| 1 | Handoff (current state) | Root HANDOFF.md — canonical current state, lessons, insights, and next plan |
| 2 | PaddleFleet migration | docs/PADDLEFLEET_MIGRATION_S74.md — stream patch, node.step() ordering, lazy main_grad, Fleet's pre-fused-weight integration path |
| 3 | This README | Root README.md — architecture, cache design, training loop, test matrix |
| 4 | Newcomer Guide | reports/sonic_moe_fp8_frontier_newcomer_guide.md — FP8/CuTe/SonicMoE basics through expert Q&A, MFU math (Section 7) |
| 5 | GemmDGated Design Doc | docs/gemm_dgated_fp8cload.md — Int16 trick, dSwiGLU SASS decomposition, FP8 blockscaled dequant, zero-materialization |
| 6 | Expert Interleave Design Doc | docs/expert_interleave_weight_layout.md — full-stack benefit analysis of gate/up interleaved weight layout |
| 7 | Engineering Log | reports/fp8_upgrade/engineering_log.md — historical lesson log only; current-state correction block at top |
| 8 | Environment | /root/paddlejob/share-storage/gpfs/system-public/panzhaowu/env.md — machine setup, Paddle compat pitfalls, perf methodology |
Project state (clean handoff, 2026-05-09): branch `race-fix-paddle`. FP8 frontier remains green: precision (out/dx/dw1/dw2/ds) cos≥0.997, determinism hard-gated, route-level padding active, TMA reduce-add wgrad default, `node.step()` MUST precede `optimizer.step()`, `main_grad` lazy-allocated. Latest reference-shape FP8 busy time is 2659.8 µs/iter (46.51% MFU, 1.63x vs true BF16). BF16 baseline (4346 µs) now verified clean via nsys (zero FP8 kernels). Read `HANDOFF.md` before any kernel work; the current P0 is structural dgrad1 optimization (live-range shortening / fission), NOT direct dz-quant epilogue fusion. MXFP8 128-row alignment waste analysis completed; see `/panzhaowu/bkup/mxfp8_alignment_waste_analysis.pdf`.
Quick-validate the frontier before resuming work:
source .runenv.sh
bash tests/run_regression.sh # full regression incl. determinism gate
# OR a fast spot-check:
python -m pytest tests/fp8_frontier_determinism_test.py -v
python -m pytest tests/ops/test_mlpnode_multilayer.py tests/ops/test_mlpnode_correctness_large.py \
tests/ops/test_colwise_quant.py tests/ops/test_rowwise_quant.py tests/ops/test_fused_quant.py
python tests/ops/test_mlpnode_precision.py  # 6-shape topk precision audit

import torch
from sonicmoe import MoE, SonicMoEConfig
from sonicmoe.enums import ActivationType
moe = MoE(num_experts=8, num_experts_per_tok=8, hidden_size=3072,
          intermediate_size=1536, activation_function=ActivationType.SWIGLU,
          add_bias=False, std=0.02).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(8192, 3072, device="cuda", dtype=torch.bfloat16)
cfg = SonicMoEConfig(use_fp8=True, use_quack_gemm=True)
with cfg.activate():
    output, aux_loss = moe(x, use_fp8=True)

Strict-baseline regression runner: every core mechanism (precision, multilayer/PP, quant kernels, JIT cold/warm/reload/reuse, nsys perf, multi-card) is measured and gated against tools/ci/baselines.json:
# fast pre-commit suite (~2 min on warm cache)
bash tools/ci/run_core_tests.sh --fast
# full suite — also runs jit-cold (~10 min), perf gate (nsys), multi-card
bash tools/ci/run_core_tests.sh

Install the pre-commit hook once per clone:
git config core.hooksPath .githooks

Offline pre-warm the JIT cache (Triton + Quack disk caches + sentinel) so production training skips the multi-minute first-loss cost:
python -m sonicmoe.cli.warmup --E 32 --H 3072 --I 1536 \
--cache-dir /nfs/sonicmoe_jit_e32_h3072_i1536
# then export SONIC_MOE_CACHE_DIR=/nfs/... on every training rank

@misc{guo2025sonicmoeacceleratingmoeio,
title={SonicMoE: Accelerating MoE with IO and Tile-aware Optimizations},
author={Wentao Guo and Mayank Mishra and Xinle Cheng and Ion Stoica and Tri Dao},
year={2025},
eprint={2512.14080},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2512.14080},
}

Apache License 2.0 - see LICENSE.




