PFCCLab/supersonic-moe

 
 

SuperSonic-MoE: Accelerating MoE with IO and Tile-aware Optimizations

SuperSonic-MoE is a blazing-fast Mixture-of-Experts (MoE) implementation optimized for NVIDIA Hopper and SM100 GPUs, leveraging CuTeDSL and Triton.

Current FP8 Frontier Snapshot (2026-05-07)

The current SM100 frontier is green on branch race-fix-paddle. The active path is:

DeepEP topk metadata
  -> route-level padding
  -> zero-materialization FP8 up-proj (A_idx, no x_gathered)
  -> fused GEMM + SwiGLU + z FP8 epilogue quant
  -> FP8 down-proj
  -> FP8-C-load GemmDGated backward
  -> iso32 dz dual quant
  -> TMA reduce-add wgrad into framework main_grad

Latest reference-shape benchmark (T=8192,H=3072,I=1536,E=8,K=8, Target GPU, nsys GPU-projection):

| Metric | Current value | Notes |
| --- | --- | --- |
| FP8 busy time | 2659.8 µs/iter | reports/fresh_benchmark_ws1/ |
| MFU | 46.51% | denominator: 4500 TFLOPS FP8 peak |
| Peak measured MFU | 51.61% | T8192-H4096-I4096-E8-K8 |
| Speedup vs true BF16 (same codebase) | 1.63x | 2659.8 vs 4346 µs; BF16 verified zero FP8 kernels |
| Speedup vs historical S53 cuBLAS BF16 | ~1.37x | 2659.8 vs 3644 µs; S53 is PyTorch-native, no Paddle proxy |
| Precision | cos >= 0.997, RRMSE < 7.6% | output/dx/ds/dw1/dw2 suites |
| Determinism | bit-exact repeated fwd/bwd | tests/fp8_frontier_determinism_test.py hard gate |

Core MFU formula:

F = 18 * TK * H * I
MFU = F / (busy_seconds * peak_FLOPs_per_second)

Reference:
  TK = 8192 * 8 = 65536
  F = 18 * 65536 * 3072 * 1536 = 5.566e12 FLOPs
  ideal @4500 TFLOPS = 1237 µs
  measured = 2659.8 µs
  MFU = 46.51%
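
The reference arithmetic above can be reproduced in a few lines of Python. This is a minimal sketch: the helper names `moe_flops` and `mfu` are illustrative and are not part of the sonicmoe package.

```python
def moe_flops(tk: int, h: int, i: int) -> int:
    """Total fwd+bwd FLOPs per the formula F = 18 * TK * H * I."""
    return 18 * tk * h * i


def mfu(tk: int, h: int, i: int, busy_seconds: float, peak_flops: float) -> float:
    """Achieved FLOP/s divided by peak FLOP/s."""
    return moe_flops(tk, h, i) / (busy_seconds * peak_flops)


# Reference shape: T=8192, K=8 -> TK=65536; H=3072, I=1536.
TK, H, I = 8192 * 8, 3072, 1536
PEAK = 4500e12      # FP8 peak (FLOP/s) used as the denominator in this README
BUSY = 2659.8e-6    # measured busy time, in seconds

ideal_us = moe_flops(TK, H, I) / PEAK * 1e6
print(f"F     = {moe_flops(TK, H, I):.3e} FLOPs")
print(f"ideal = {ideal_us:.0f} us")
print(f"MFU   = {mfu(TK, H, I, BUSY, PEAK):.2%}")
```

Swapping in another row of the performance table (for example T=16384 with its measured busy time) is a quick way to check a newly reported MFU value against its busy time.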

Performance Model

$$T_{\text{proj}}(\text{TK}, H, I, E) = \underbrace{\alpha \cdot \text{TK} \cdot H \cdot I}_{\text{compute}} + \underbrace{\beta \cdot \text{TK} \cdot D \cdot \ln\left(1 + \frac{\text{TK}}{\text{TK}_0}\right)}_{\text{L2 miss penalty}} + \underbrace{\gamma \cdot E}_{\text{expert setup}}$$

$$\text{MFU} = \frac{18 \cdot \text{TK} \cdot H \cdot I}{P \cdot T_{\text{proj}}}, \quad D = \max(H, 2I), \quad P = 4500 \text{ TFLOPS (FP8 peak)}$$

| Parameter | Value | Derivation |
| --- | --- | --- |
| α | 7.491 × 10⁻⁹ | = 18/(P · η_tc), η_tc = 53.4% |
| β | 9.520 × 10⁻⁷ | L2 miss penalty per (access × refetch-width) |
| TK₀ | 39 747 | ≈ (L2/E) / (tile_M · tile_N · b_elem) · tile_M = (96MB/8) / 64KB · 256 ≈ 49K (fitted: 40K) |
| γ | 21.72 µs | Per-expert TMA descriptor + metadata |

Fit: MAPE = 3.79%, MFU MAE = 1.49%, N = 104 (uniform 4E × 4HI × 6TK grid + 16 large-TK).
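
The fitted model can be evaluated directly. The sketch below assumes the coefficients yield times in microseconds, which is consistent with γ being quoted in µs and with α = 18/(P · η_tc) evaluating to 7.491 × 10⁻⁹ µs. The function names are illustrative only.

```python
import math

# Fitted parameters from the table above; times are assumed to be in microseconds.
ALPHA = 7.491e-9     # compute coefficient
BETA = 9.520e-7      # L2 miss coefficient
TK0 = 39_747         # fitted L2 onset token count
GAMMA = 21.72        # per-expert setup time
PEAK_FLOPS = 4500e12 # FP8 peak, FLOP/s


def t_proj_us(tk: int, h: int, i: int, e: int) -> float:
    """Projected busy time: compute + L2 miss penalty + expert setup."""
    d = max(h, 2 * i)
    compute = ALPHA * tk * h * i
    l2_penalty = BETA * tk * d * math.log1p(tk / TK0)
    expert_setup = GAMMA * e
    return compute + l2_penalty + expert_setup


def model_mfu(tk: int, h: int, i: int, e: int) -> float:
    """MFU predicted by the model for a given shape."""
    flops = 18 * tk * h * i
    return flops / (PEAK_FLOPS * t_proj_us(tk, h, i, e) * 1e-6)


ref = model_mfu(65536, 3072, 1536, 8)
large = model_mfu(4 * 1024 * 1024, 3072, 1536, 8)
print(f"reference shape: {ref:.1%}, TK=4M: {large:.1%}")
```

For the reference shape the three terms come out to roughly 2316 µs, 187 µs and 174 µs, so compute dominates and the L2 and expert-setup terms together account for the gap between η_tc and the modelled MFU.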

Term semantics

| # | Term | Physical origin | Why this functional form |
| --- | --- | --- | --- |
| 1 | α · TK · H · I | 6 CUTLASS GEMMs totalling 18 · TK · H · I FLOPs. η_tc < 1 due to persistent-scheduler per-tile overhead and DGated epilogue register pressure (168 regs → 1 block/SM). | Proportional to F/P at constant η_tc. Shape-invariant: same tile config (256², cluster 2×1) across all (H, I). |
| 2 | β · TK · D · ln(1 + TK/TK₀) | Weight-tile L2 reuse breakdown. Scheduler cycles TK/tile_M tiles; when this exceeds L2 capacity per expert (~192 sets), multi-stream SM contention drives miss rate proportional to ln(TK/TK₀). Each miss refetches D bytes. | D = max(H, 2I): widest tile dimension among the 6 GEMMs' B-tensors. TK: total L2 accesses ∝ token count. ln(1+x): smooth onset; empirical consensus for shared-cache thrashing under randomized multi-tenant load. |
| 3 | γ · E | Per-expert one-time: TMA descriptor, weight-cache version check, routing histogram. | O(1) per expert, independent of TK, H, I. E=8: 174 µs (< 7%); E=128: 2784 µs (dominant only when TK < 50K). |

MFU phase diagram

$$\text{MFU}(\text{TK}) = \frac{18 \cdot H \cdot I}{P \left[ \alpha \cdot H \cdot I + \beta \cdot D \cdot \ln\left(1 + \frac{\text{TK}}{\text{TK}_0}\right) + \frac{\gamma \cdot E}{\text{TK}} \right]}$$

| Phase | Regime | Limiting behavior |
| --- | --- | --- |
| Rise | TK ≪ TK₀ | MFU ≈ 18 · H · I · TK / (P · γ · E), linear ↑ |
| Peak | TK ~ TK₀ | Maximum, ∂MFU/∂TK = 0 |
| Decay | TK ≫ TK₀ | MFU ∝ 1/ln(TK), logarithmic ↓ |
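
The peak location follows from the MFU(TK) expression: the α · H · I term does not depend on TK, so the maximum of MFU is the minimum of g(TK) = β · D · ln(1 + TK/TK₀) + γ · E / TK. Setting g′(TK) = 0 gives the quadratic β · D · TK² − γ · E · TK − γ · E · TK₀ = 0. The sketch below solves it in closed form; the function names are illustrative, and the constants are the fitted values from the parameter table.

```python
import math

BETA, TK0, GAMMA = 9.520e-7, 39_747, 21.72  # fitted values from the parameter table


def g(tk: float, h: int, i: int, e: int) -> float:
    """TK-dependent part of the MFU denominator."""
    return BETA * max(h, 2 * i) * math.log1p(tk / TK0) + GAMMA * e / tk


def tk_star(h: int, i: int, e: int) -> float:
    """Positive root of beta*D*TK^2 - gamma*E*TK - gamma*E*TK0 = 0."""
    bd = BETA * max(h, 2 * i)
    ge = GAMMA * e
    return (ge + math.sqrt(ge * ge + 4.0 * bd * ge * TK0)) / (2.0 * bd)


peak = tk_star(3072, 1536, 8)
print(f"TK* for the reference shape at E=8: {peak:,.0f}")
```

For the reference shape at E = 8 this gives roughly 87K token-expert pairs, the same order of magnitude as TK₀. Because γ · E appears under the square root, TK* moves to larger token counts as E grows.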

Validation (E = 8)

| Shape | Measured peak | Model@peak | Measured @4M | Model@4M |
| --- | --- | --- | --- | --- |
| Reference | 45.7% @65K | 46.3% | 38.3% | 38.3% |
| H=4096 I=2048 | 50.5% @65K | 48.5% | 40.6% | 41.0% |
| H=4096 I=4096 | 52.4% @33K | 49.1% | 43.6% | 41.0% |
| H=6144 I=3072 | 52.9% @33K | 50.5% | 44.3% | 44.6% |

MFU model curves and contours

Left column: MFU vs TK with model curves (solid) and nsys-measured points (markers). Dashed line = steady-state.
Right column: MFU contour in (log₁₀TK, E) space; red curve = TK* (peak MFU trajectory).

Read first:

| Priority | Document | Purpose |
| --- | --- | --- |
| 1 | HANDOFF.md | canonical current project state, lessons, next plan |
| 2 | reports/sonic_moe_fp8_frontier_newcomer_guide.md | standalone FP8/CuTe/SuperSonic-MoE newcomer guide + expert Q&A |
| 3 | docs/gemm_dgated_fp8cload.md | GemmDGatedFP8CLoad design: Int16 trick, dSwiGLU, FP8 dequant, zero-materialization |
| 4 | docs/expert_interleave_weight_layout.md | Expert Interleave weight layout: 5-layer benefit analysis + originality |
| 5 | reports/sonic_moe_comprehensive_analysis.md | broad technical analysis and roofline/precision/perf tables |
| 6 | reports/fresh_benchmark_ws1/README.md | latest sweep data and MFU fit |
| 7 | reports/reference_shape_ncu/README.md | NCU resource breakdown for the 6 GEMMs (reference shape) |

Important current insights:

  • GemmDGatedFP8CLoadSm100ZeroMat is the main structural bottleneck: 168 regs/thread, ~42% tensor-pipe. Do not directly add dz quant loops to this epilogue.
  • TMA reduce-add is a register/performance optimization, not higher-precision accumulation.
  • iso32 dz dual quant is measured safe for current reference-shape dz; monitor log2(block_amax/row_amax) before generalizing.
  • SonicMoEMlpNode.step() must run before optimizer.step() because it flushes native CUTLASS wgrad layout into framework main_grad.

ISO32 Weight Cache Unification

ISO32 (32×32 block) FP8 weight quantization stores one buffer per weight instead of two transposed copies by exploiting the byte-identical transpose invariant of isotropic block scaling. Forward and backward GEMM kernels consume the same physical buffer via zero-copy stride views. Controlled by SONIC_MOE_FP8_ISO32_WEIGHT=1 (default OFF).

Memory saving: 48.5% of weight FP8 cache (108 MiB at E=8, H=3072, I=1536).
Precision: Identical to baseline 1×32 path vs BF16 golden (ratio=1.0000, verified across 9 shapes).
Performance: −0.5% GPU-projection (1 fewer unique kernel, identical GEMM signatures).
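
The transpose invariant can be illustrated without a GPU. In the sketch below, integer rounding stands in for the FP8 cast (an assumption made to keep the example in pure Python), and `block_quant` is a hypothetical helper, not the repository's quantizer. With square 32×32 blocks each scale covers the same set of elements before and after transposition, so quantizing Wᵀ yields exactly the transpose of the quantized W. With 1×32 row blocks the groups differ and the property is lost.

```python
import random


def block_quant(mat, bh, bw, qmax=127):
    """Quantize `mat` with one scale per (bh x bw) block; returns (q, scales)."""
    rows, cols = len(mat), len(mat[0])
    q = [[0] * cols for _ in range(rows)]
    scales = {}
    for r0 in range(0, rows, bh):
        for c0 in range(0, cols, bw):
            block = [(r, c) for r in range(r0, min(r0 + bh, rows))
                     for c in range(c0, min(c0 + bw, cols))]
            amax = max(abs(mat[r][c]) for r, c in block)
            scale = amax / qmax if amax > 0 else 1.0
            scales[(r0 // bh, c0 // bw)] = scale
            for r, c in block:
                q[r][c] = round(mat[r][c] / scale)
    return q, scales


def transpose(mat):
    return [list(col) for col in zip(*mat)]


random.seed(0)
w = [[random.gauss(0.0, 1.0) for _ in range(96)] for _ in range(64)]

# Isotropic 32x32 blocks: quantize(W^T) equals transpose(quantize(W)).
q_iso, s_iso = block_quant(w, 32, 32)
q_iso_t, s_iso_t = block_quant(transpose(w), 32, 32)
iso_invariant = q_iso_t == transpose(q_iso) and all(
    s_iso_t[(c, r)] == s for (r, c), s in s_iso.items())

# 1x32 row blocks: the transposed tensor groups different elements per scale.
q_row, _ = block_quant(w, 1, 32)
q_row_t, _ = block_quant(transpose(w), 1, 32)
row_invariant = q_row_t == transpose(q_row)

print("32x32 invariant:", iso_invariant, "| 1x32 invariant:", row_invariant)
```

This is why a single stored buffer plus a stride view can serve both the forward and the backward GEMM under ISO32, while the 1×32 baseline needs a separately quantized transposed copy.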

ISO32 vs Baseline weight dataflow

Left: Baseline pair-kernel path (4 separate FP8 weight buffers). Right: ISO32 single-buffer path (2 buffers + stride views). Dashed boxes = eliminated allocations.

Full DeepEP pipeline dataflow

Full DeepEP pipeline: Router → All-to-All dispatch → FP8 FFN (with persistent weight cache) → Combine → reverse A2A. Three gradient paths: dx (autograd), ds (router gate), dw (main_grad accumulate).

Validation report: reports/iso32_weight_validation_report.md

Prerequisites

  • NVIDIA Hopper GPUs (H100, H200) or SM100 GPUs (Target GPU, Target GPU, Target GPU)
  • CUDA 12.9+ (13.0+ for Target GPU)
  • Python 3.12+, PyTorch 2.7+
  • USE_QUACK_GEMM=1 for SM100 kernels

Default Path & Environment Flags

The default code path is the fused-v2 frontier (FP8 epilogue blockscaled quant + fused-gated up-proj + TMA reduce-add wgrad + CUDA fused topk-metadata kernel). Most flags below are already set to their production-recommended defaults; they are documented here so other agents can quickly reason about the active path.

Sonic-meta routing kernel (always-on)

The fused CUDA topk→metadata kernel (sonicmoe.ernie_compat.deepep_topk_metadata_cuda) is JIT-compiled on first import and dispatched automatically by deepep_topk_to_sonic_metadata. There is no env flag to gate it: when the import succeeds it is always used; when the build fails the code falls back to the Triton-only path. To force the legacy path for A/B comparison, delete the build directory before import:

rm -rf $TORCH_EXTENSIONS_DIR/sonicmoe_deepep_topk_metadata_cuda  # (or matching paddle build dir)

FP8 / Frontier flags

Active production flags (set these for training):

| Env flag | Default | What it gates |
| --- | --- | --- |
| SONIC_MOE_FP8_MODE | unset | Master switch. perf = full FP8 frontier; mem = FP8 + stage-wise memory reuse; unset = BF16. |
| USE_QUACK_GEMM | unset | Required on SM100 for CuTeDSL FP8 GEMMs. |
| SONIC_MOE_FP8_WGRAD | unset | FP8 weight gradients (otherwise BF16 wgrad). Always 1 for production. |
| SONIC_MOE_FP8_ASSUME_ALIGNED | 0 | Skip runtime padding-check H2D sync. Set 1 for zero-sync training (requires aligned token counts). |
| SONIC_MOE_FP8_ISO32_WEIGHT | 0 | ISO32 weight cache unification: stores ONE fp8 buffer per weight (saves 48.5% weight cache memory). |
| SONIC_MOE_FP8_RECOMPUTE_Z | 0 | Skip storing z_fp8 in fwd; rerun up-proj in bwd. Saves ~213 MiB/layer, costs ~5–15% extra time. |
| SONIC_MOE_STAGEWISE_MEMORY | 0 | Free activations eagerly between stages. Saves ~1.0–1.5 GB at reference shape, costs 3–5%. |
| SONIC_MOE_CACHE_DIR | ~/.cache/sonicmoe | JIT compile-cache directory. |

Hardcoded defaults (always-on in the frontier path, no need to set):

| Flag | Value | Notes |
| --- | --- | --- |
| SONIC_MOE_FP8_FUSED_GATED | 1 | Fused SwiGLU+quant in up-proj epilogue. Non-fused path is deprecated. |
| SONIC_MOE_FP8_EPILOGUE_QUANT | 1 | FP8 quant inside GEMM epilogue. Always active when FP8 enabled. |
| SONIC_MOE_FP8_FUSED_SWIGLU_QUANT | 1 | Fused SwiGLU+rowquant on y1 path. Always active. |
| SONIC_MOE_FP8_SAVE_Z_FP8 | 1 | Save z as FP8 for backward (vs BF16). Always active. |

Deprecated / do-not-use:

| Flag | Reason |
| --- | --- |
| SONIC_MOE_FP8_RECOMPUTE_OPT_B | Produces illegal-instruction on non-uniform routing. Will be removed. |
| SONIC_MOE_FP8_WGRAD_BETA_ACCUM | Legacy fused-beta epilogue (86 regs). TMA reduce-add (50 regs) is strictly better. |
| SONIC_MOE_FP8_UPPROJ_EPILOGUE_PRECISION | Always fp8-blockscaled. No other value tested or supported. |
| SONIC_MOE_FP8_DOWNPROJ_MAINLOOP_PRECISION | Always fp8-blockscaled. |
| SONIC_MOE_FP8_DOWNPROJ_WEIGHT_PRECISION | Always bf16. FP8 weight not supported in down-proj mainloop. |
| SONIC_MOE_FP8_FUSED_ZY1_QUANT | Experimental dual z+y1 quant kernel. Slower than separated path on production shapes. |
| SONIC_MOE_FP8_BLOCKSCALED_EXPERT_CAPACITY | Benchmark-only override. Never set in production. |

Production launcher:

USE_QUACK_GEMM=1 \
SONIC_MOE_FP8_MODE=perf \
SONIC_MOE_FP8_WGRAD=1 \
SONIC_MOE_FP8_ASSUME_ALIGNED=1 \
TRITON_PTXAS_PATH=/usr/local/cuda-13.0/bin/ptxas \
python train.py
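
When the trainer is started from Python rather than a shell wrapper, the same flags can be set programmatically. The sketch below is an illustration, not part of sonicmoe: `apply_production_flags` is a hypothetical helper, and it assumes the flags must be in the environment before `import sonicmoe` (this README states that ordering for SONIC_MOE_FP8_MODE). TRITON_PTXAS_PATH is left out because its value is machine-specific.

```python
import os

# Flags copied from the production launcher above.
PRODUCTION_FLAGS = {
    "USE_QUACK_GEMM": "1",
    "SONIC_MOE_FP8_MODE": "perf",
    "SONIC_MOE_FP8_WGRAD": "1",
    "SONIC_MOE_FP8_ASSUME_ALIGNED": "1",
}


def apply_production_flags(environ=os.environ, overwrite=False):
    """Set the frontier flags on `environ`; return the names that were changed."""
    changed = []
    for name, value in PRODUCTION_FLAGS.items():
        if overwrite or name not in environ:
            environ[name] = value
            changed.append(name)
    return changed


if __name__ == "__main__":
    print("set:", apply_production_flags())
    # import sonicmoe  # import only after the flags are in place
```

By default the helper leaves any value already exported by the job scheduler untouched; pass `overwrite=True` to force the production values.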

Paddle Integration (PaddleFleet)

The race-fix-paddle branch integrates SonicMoE into Paddle via paddle.compat.enable_torch_proxy. Production entry point: SonicMoEMlpNode.

Training Loop (S74+ contract — node.step() BEFORE optimizer.step())

from sonicmoe.ernie_compat import SonicMoEMlpNode

node = SonicMoEMlpNode(experts, n_experts=E, hidden_size=H, intermediate_size=I)

# ── Cold start: first fwd+bwd triggers JIT compilation (~42s) ──
# For explicit warmup before training:
#   from sonicmoe.jit_warmup import warmup_jit
#   warmup_jit(E=8, H=3072, I=1536, device="cuda")

for step in range(num_steps):
    for mb in microbatches:
        out = node(x, tokens_per_expert, dispatched_indices, dispatched_probs)
        out.backward(grad)
    node.step()           # MUST run BEFORE optimizer.step():
                          # converts native CUTLASS [E,2I,H] → framework split-half
                          # [E,H,2I] in-place into expert.weight.main_grad,
                          # which the optimizer then reads.
    optimizer.step()
    optimizer.clear_grad()

main_grad is lazy-allocated on first backward (S74 follow-up); inference / warmup-only flows pay zero main_grad memory. Weight cache invalidation is automatic via (data_ptr, _inplace_version(w)) keys — no manual clear_* needed after in-place optimizer updates.
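
The version-keyed invalidation can be pictured with a small stand-alone model. Everything here is hypothetical scaffolding: `FakeWeight` imitates a tensor's storage address and in-place version counter, and the oldest-first eviction is an assumption (the README only states a cap of 8 entries per cache).

```python
from collections import OrderedDict


class FakeWeight:
    """Stand-in for a framework tensor: storage address + in-place version counter."""

    def __init__(self, ptr, values):
        self.ptr, self.values, self.version = ptr, list(values), 0

    def add_(self, delta):
        """In-place update, as an optimizer step would perform."""
        self.values = [v + delta for v in self.values]
        self.version += 1


class VersionKeyedCache:
    def __init__(self, max_entries=8):
        self._entries = OrderedDict()
        self._max = max_entries
        self.misses = 0

    def get(self, weight, build):
        key = (weight.ptr, weight.version)      # stale entries simply stop matching
        if key not in self._entries:
            self.misses += 1
            self._entries[key] = build(weight)
            if len(self._entries) > self._max:
                self._entries.popitem(last=False)   # assumed policy: drop oldest
        return self._entries[key]


cache = VersionKeyedCache()
w = FakeWeight(ptr=0x1000, values=[1.0, 2.0])
quantize = lambda t: [round(v) for v in t.values]   # placeholder for FP8 quantization

first = cache.get(w, quantize)
again = cache.get(w, quantize)     # same (ptr, version): cache hit
w.add_(1.0)                        # in-place optimizer update bumps the version
fresh = cache.get(w, quantize)     # new key: rebuilt from the updated weights
print(cache.misses, first, fresh)
```

Because the key changes whenever the weight is modified in place, no explicit clear call is needed between optimizer steps; old entries age out through the size cap.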

Cold Start vs Hot Start

| Phase | What happens | Time |
| --- | --- | --- |
| Import | CUDA topk metadata kernel compiled | ~4s |
| 1st fwd+bwd | CuTe CUTLASS GEMM + all Triton kernels JIT compiled | ~42s |
| 2nd fwd+bwd | All caches hit, steady-state | 0.05s |
| New seqlen | CuTe GEMM: 0 recompile (dynamic dims via mark_layout_dynamic). Triton: ~2.5s per new TK (cached in ~/.triton/cache across sessions) | 0-2.5s |
| Before optimizer.step() | Call node.step() → flushes native-layout wgrad to per-expert main_grad, clears weight/FP8/topk caches | <1ms |

JIT Cache Architecture

Four tiers of caching, each with a different invalidation strategy:

| Cache | Key includes | Invalidated by | Max size |
| --- | --- | --- | --- |
| CuTe compile cache (_COMPILE_CACHE*) | Static model dims (H, I, E, dtype, tile config) only. No token counts. | Never (persistent for model lifetime) | Unbounded (typically 3-8 entries) |
| Fast-path runtime cache (_GEMM_FAST_PATH*) | Exact problem shape (total_M/K + all tensor dims) | Auto-eviction at 64 entries | 64 |
| FP8 weight cache (_WEIGHT_CACHE etc.) | data_ptr + _inplace_version + shape + stride | node.step() / invalidate_weight_caches() | 8 per cache |
| Triton JIT cache (~/.triton/cache/) | Full kernel source hash | rm -rf ~/.triton/cache | Unbounded (disk) |

Design principle: compile_key contains only static model dimensions — never TK, total_M, total_K, capacity, or any token-count-dependent value. Dynamic token dimensions are handled at runtime via CuTe's mark_layout_dynamic. This ensures zero CuTe recompilation when batch size or routing distribution changes.

Multi-process JIT cache on shared GPFS (production)

Production deploys 1 process per GPU; every rank on every node sees the same build/ directory and SONIC_MOE_CACHE_DIR over GPFS. The JIT layer is fully race-safe under concurrent cold start across ranks:

| Layer | Race-safety mechanism |
| --- | --- |
| C++ extensions (sonicmoe/jit.py) | Cross-process FileLock at <parent>/.{module_name}.lock (stable inode survives build/<module>/ wipe). After lock acquisition the worker first tries _try_import_prebuilt, which directly dlopen-imports <build_dir>/<name>/<name>.so (PYBIND11), bypassing paddle's racy load() whenever the artifact already exists. Only the first rank actually compiles; later ranks fast-path import. |
| Triton kernels (~/.triton/cache, or $TRITON_CACHE_DIR) | Per-key subdirs; same-key concurrent writes use Triton's atomic temp+rename. |
| Quack / CuTe autotuner (autotuner.py:check_disk_cache) | sha256(VERSION + tuning_key + configs) → per-key file. Different shapes write to different files; same-key writes are atomic. Once any rank has tuned a (dtype, K, config) combination, all other ranks read the same JSON on the next iteration. |
| paddle.enable_compat() ordering | _get_cpp_function re-arms the torch.utils.hipify proxy blocker (install_quack_paddle_compat()) inside the FileLock immediately before cpp_extension.load(), so the patches are live regardless of whether the consumer imported sonicmoe before or after paddle.enable_compat(). |

Heterogeneous-shape concurrent cold starts (e.g. GPU0 hits total_K=4096, GPU1 hits total_K=8192 simultaneously on a fresh cache) are explicitly covered by tests/ops/test_jit_concurrent_heterogeneous.py and gated in CI under phase jit-concurrent.
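
The lock-then-fast-path pattern described for the C++ extension layer can be sketched with the standard library. This is an illustration under stated assumptions, not the code in sonicmoe/jit.py: it uses POSIX `fcntl.flock` instead of the FileLock the repository uses, `build_once` and `compile_fn` are hypothetical names, and advisory-lock behaviour on a network filesystem depends on how that filesystem is mounted.

```python
import fcntl
import os
import tempfile


def build_once(parent_dir, module_name, compile_fn):
    """Return the artifact path, compiling at most once across processes."""
    # The lock file lives in the parent directory so it survives a wipe of build_dir.
    lock_path = os.path.join(parent_dir, f".{module_name}.lock")
    build_dir = os.path.join(parent_dir, module_name)
    artifact = os.path.join(build_dir, f"{module_name}.so")
    with open(lock_path, "a") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)      # blocks until this process owns the lock
        try:
            if not os.path.exists(artifact):       # otherwise: prebuilt fast path
                os.makedirs(build_dir, exist_ok=True)
                fd, tmp_path = tempfile.mkstemp(dir=build_dir)
                with os.fdopen(fd, "wb") as out:
                    out.write(compile_fn())
                os.replace(tmp_path, artifact)     # atomic publish via temp + rename
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
    return artifact


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        path = build_once(root, "demo_ext", lambda: b"fake shared object")
        print("artifact:", os.path.basename(path))
```

The first caller pays for `compile_fn`; every later caller, in the same or another process, takes the lock, sees the artifact, and returns immediately.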

Recommended deployment: pre-warm once per cluster, then point every rank at the shared cache:

python -m sonicmoe.cli.warmup --E 32 --H 3072 --I 1536 \
    --cache-dir /gpfs/.../sonicmoe_jit_e32_h3072_i1536
# every rank
export SONIC_MOE_CACHE_DIR=/gpfs/.../sonicmoe_jit_e32_h3072_i1536

Skipping pre-warm is also safe: ranks will JIT-compile in parallel under the locks above; the cluster pays the cold-start cost once and converges to steady-state on iteration ≥ 2.

Custom integration (PaddleFleet / external trainers)

  1. Call paddle.enable_compat() (or paddle.compat.enable_torch_proxy(scope={"sonicmoe", "quack", "triton"})) before import sonicmoe. The reverse order is also tolerated — _get_cpp_function re-arms patches lazily — but the recommended order avoids a duplicate proxy install on the hot path.
  2. Set SONIC_MOE_CACHE_DIR to a shared GPFS path so all ranks reuse Triton/Quack disk caches; otherwise each rank caches under ~/.cache/sonicmoe.
  3. Construct one SonicMoEMlpNode per fused expert group (see Paddle Integration above). Call node.step() once per optimizer step, after the last microbatch's backward pass and BEFORE optimizer.step(); this flushes native-layout wgrad into per-expert main_grad.
  4. For diagnostics: SONIC_MOE_VALIDATE_DISPATCH=1 enables an O(T) per-row uniqueness check on dispatched_indices before kernel launch (catches malformed routing in custom dispatchers).
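
For custom dispatchers, the per-row uniqueness contract from step 4 can also be checked on the host before tensors are handed over. The function below is a plain-Python sketch, not the validator in deepep_metadata.py; `pad_id` is a hypothetical parameter for dispatchers that mark unused slots with a sentinel value.

```python
def validate_dispatched_indices(rows, num_experts, pad_id=None):
    """Raise ValueError if a row repeats an expert id or holds an out-of-range id.

    rows: iterable of per-token expert-id sequences (one row per token).
    pad_id: optional sentinel for unused slots; such entries are skipped.
    """
    for row_idx, row in enumerate(rows):
        seen = set()
        for expert in row:
            if pad_id is not None and expert == pad_id:
                continue
            if not 0 <= expert < num_experts:
                raise ValueError(f"row {row_idx}: expert id {expert} out of range")
            if expert in seen:
                raise ValueError(f"row {row_idx}: duplicate expert id {expert}")
            seen.add(expert)


validate_dispatched_indices([[0, 3, 5], [7, 1, 2]], num_experts=8)   # passes silently
try:
    validate_dispatched_indices([[0, 3, 3]], num_experts=8)
except ValueError as exc:
    print("rejected:", exc)
```

The check visits each entry once, so its cost grows linearly with the number of tokens for a fixed top-k.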

Lessons learned (S77 strict CI hardening)

  • paddle.distributed.launch on paddlejob: WHITELIST env, never denylist. The paddlejob cluster exports PADDLE_TRAINERS=4 IPs, DISTRIBUTED_TRAINER_ENDPOINTS (32 entries), POD_*, EKS_POD_*, PADDLE_CURRENT_ENDPOINT, PADDLE_CLUSTER_TRAIN, PADDLE_IS_LOCAL=0, etc. Any one of these makes the launcher silently enter multi-NODE rendezvous mode and block forever (no error, no children, launcher in R state). tools/ci/multicard_smoke.py builds env from a strict prefix whitelist (PATH/LD_/HOME/USER/LANG/CUDA_/NVIDIA_/TRITON_/SONIC_MOE_/FLAGS_/NCCL_/GLOG_/OMP_/PYTHON…).
  • Place(gpu:N) is not supported is almost always a lazy device-pool init bug. The pool only registers the place named by FLAGS_selected_gpus; any other place errors. Eager-allocate a 1-element tensor right after paddle.device.set_device(...) to force pool registration before any async path (autograd backward, paddle.library proxies inside quack JIT, paddle.tensor.random.gaussian) hits it. Same root cause as the production crash from quack.autotuner._gpu_warmup.
  • _FP8Config() snapshots is_fp8_active() at construction, not at use. Always construct it INSIDE the with enable_fp8(True): block; otherwise a prior test/code path's enable_fp8(False) lingering state leaks and the wgrad path silently falls back to BF16. Belt-and-braces: set SONIC_MOE_FP8_MODE=perf BEFORE import sonicmoe.
  • Multi-process JIT cache on shared GPFS: sonicmoe.jit uses FileLock on a stable parent directory (NOT inside build_directory, which paddle wipes mid-cycle). Triton + Quack disk caches naturally support multi-rank reuse; sentinel ({cache_root}/warmup_sentinel.json) gates cold compile. Cross-rank shape divergence (rank 0 sees shape A, rank 1 sees shape B) is supported because each rank takes the file lock per-key.
  • quack import path: /usr/local/bin/python (the default python on this host) lacks the quack site-package; it lives at /root/paddlejob/share-storage/gpfs/system-public/zhangyichen/sonicmoe_deps/quack. tests/conftest.py injects this into sys.path automatically; tools/ci/jit_bench.py::_run_subprocess and tools/ci/multicard_smoke.WORKER_BODY do the same for spawned workers.
  • Hipify proxy install order matters. paddle's torch-proxy intercepts torch.utils.hipify.hipify_python lookups and raises KeyError. Fix: pre-import the real torch.utils.hipify[.hipify_python] and add it to paddle.compat.extend_torch_proxy_blocked_modules. Re-arm inside the JIT lock so the patch applies regardless of import sonicmoe vs paddle.enable_compat() ordering.
  • SM100 ptxas mismatch. Bundled Triton 3.5 ptxas does not recognize SM100. Tests/CI globally export TRITON_PTXAS_PATH=/usr/local/cuda-13.0/bin/ptxas (handled in tests/conftest.py::_ensure_sm100_ptxas). Production deployment should set this if running on Target GPU.
  • dispatched_indices per-row uniqueness contract. The kernel (sonicmoe/ernie_compat/deepep_topk_metadata_cuda/kernel.cu L315-333) assumes each row of dispatched_indices contains pairwise-distinct expert ids. Real DeepEP dispatch always satisfies this; custom test fixtures must. The optional contract validator at deepep_metadata.py (gated by SONIC_MOE_VALIDATE_DISPATCH=1) catches violations early.
  • paddle's cpp_extension.load() writes a wrapper .py whose __bootstrap__ references *_pd_.so. In this paddle build the file is plain *.so and the wrapper is non-functional; importing the .so directly via spec_from_file_location (PYBIND11 handles symbol export) is the correct fast-path.
  • Stable lock paths matter. paddle's load() may wipe its build_directory mid-cycle. The FileLock must live in the parent dir so the lock inode survives.
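
The whitelist lesson from the first bullet can be sketched as a prefix filter over the parent environment. This is an illustration, not the code in tools/ci/multicard_smoke.py: `whitelist_env` is a hypothetical name, and the prefix tuple contains only the entries visible in the bullet, whose list is truncated in the source.

```python
import os

# Prefixes visible in the whitelist bullet above (the source list ends with "…").
ALLOWED_PREFIXES = (
    "PATH", "LD_", "HOME", "USER", "LANG", "CUDA_", "NVIDIA_", "TRITON_",
    "SONIC_MOE_", "FLAGS_", "NCCL_", "GLOG_", "OMP_", "PYTHON",
)


def whitelist_env(environ=None, prefixes=ALLOWED_PREFIXES):
    """Return a copy of `environ` containing only variables with an allowed prefix."""
    source = os.environ if environ is None else environ
    return {name: value for name, value in source.items() if name.startswith(prefixes)}


cluster_env = {
    "PATH": "/usr/bin",
    "SONIC_MOE_FP8_MODE": "perf",
    "PADDLE_TRAINERS": "10.0.0.1,10.0.0.2",
    "DISTRIBUTED_TRAINER_ENDPOINTS": "10.0.0.1:6170",
    "POD_IP": "10.0.0.1",
}
print(sorted(whitelist_env(cluster_env)))
```

The filtered dictionary can be passed as the `env` argument of `subprocess.run` or `subprocess.Popen`, so a variable the cluster adds later is dropped by default instead of having to be discovered and denylisted.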

Gradient Contract

| Gradient | Mechanism | Verified |
| --- | --- | --- |
| dx (d/d hidden_states) | Paddle autograd through _SonicMoEDeepEPFunc.backward | cos=0.9975 |
| ds (d/d dispatched_probs) | _GatherRouterScores PyLayer with custom Triton scatter (no CUB cascade) | cos=0.9971–0.9973 |
| dw1, dw2 | CUTLASS wgrad accumulates directly into the per-instance fused [E, 2I, H] / [E, H, I] native buffer (lazy-allocated on first backward); node.step() performs the in-place native→framework split-half layout conversion into expert.weight.main_grad before optimizer.step() reads it. | cos=0.9975 / 0.9971 |

Precision (Session 65, FP8 vs BF16 gold, TMA Reduce-Add epilogue)

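| N | K | E | I | out | dx | dw1 | dw2 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 128 | 4 | 4 | 384 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 128 | 8 | 8 | 384 | 0.9979 | 0.9975 | 0.9975 | 0.9971 |
| 512 | 4 | 8 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 512 | 8 | 8 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 1024 | 8 | 8 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9972 |
| 256 | 8 | 32 | 1536 | 0.9979 | 0.9975 | 0.9975 | 0.9971 |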
ds gradient verified via test_cold_start_e2e.py: cos=0.9972 across all 6 shapes (1024/8192/4096/2048/512/16384 tokens).

All cosine similarities are ≥ 0.997 (the CI gate is cos > 0.99) and RRMSE < 7.6%. Shapes include E=32 (production), varying topk (4/8), and small/large token counts.
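
The two precision metrics can be computed as follows. This is a sketch with one labeled assumption: the README does not define RRMSE, so it is taken here as ‖test − ref‖₂ / ‖ref‖₂. The function names are illustrative.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two flattened tensors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rrmse(test, ref):
    """Relative RMSE, assumed to be ||test - ref||_2 / ||ref||_2."""
    err = math.sqrt(sum((x - y) ** 2 for x, y in zip(test, ref)))
    return err / math.sqrt(sum(y * y for y in ref))


bf16_gold = [1.0, 2.0, 3.0, 4.0]          # toy stand-in for a BF16 reference output
fp8_out = [1.01, 1.98, 3.03, 3.96]        # toy stand-in for the FP8 output
print(f"cos = {cosine_similarity(fp8_out, bf16_gold):.5f}")
print(f"RRMSE = {rrmse(fp8_out, bf16_gold):.2%}")
```

Under this definition the two thresholds are mutually consistent: if the quantization error is roughly orthogonal to the reference, cos ≈ 1/√(1 + RRMSE²), so an RRMSE of 7.6% corresponds to cos ≈ 0.9971.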

Performance (nsys GPU-projection, Target GPU, current fresh data)

Current canonical data lives in reports/fresh_benchmark_ws1/. The table below supersedes older Session-65-only summaries for frontier-level comparisons.

| Shape | QuACK BF16 (µs) | FP8 frontier (µs) | Speedup vs QuACK BF16 | FP8 MFU |
| --- | --- | --- | --- | --- |
| T=1024 H=3072 I=1536 E=8 K=8 | 540.0 | 566.0 | 0.95x | 27.32% |
| T=2048 H=3072 I=1536 E=8 K=8 | 847.7 | 870.1 | 0.97x | 35.54% |
| T=4096 H=3072 I=1536 E=8 K=8 | 1533.4 | 1459.1 | 1.05x | 42.39% |
| T=8192 H=3072 I=1536 E=8 K=8 | 2942.5 | 2659.8 | 1.11x | 46.51% |
| T=16384 H=3072 I=1536 E=8 K=8 | 6022.1 | 5224.9 | 1.15x | 47.35% |
| T=8192 H=4096 I=4096 E=8 K=8 | 10894.7 | 8521.7 | 1.28x | 51.61% |

Historical S53 BF16 numbers were cuBLAS/PyTorch-style and slower than the current QuACK BF16 path; relative to that historical baseline, reference-shape FP8 is ~1.37x. Always label which BF16 baseline is used.

TMA reduce-add remains part of the current default path. It replaced legacy D=A@B+1.0*C wgrad accumulation (86 regs/thread) with TMA store-side ADD (~50 regs/thread), saving 2-4% end-to-end depending on E. It does not raise accumulation precision; main_grad remains fp32 and determinism is test-gated.

See HANDOFF.md for full kernel breakdown, memory notes, and next-step priorities.

Key Files

| File | Purpose |
| --- | --- |
| sonicmoe/ernie_compat/mlp_node_v2.py | SonicMoEMlpNode: production MlpNode with .forward(), .step(), .warmup() |
| sonicmoe/jit_warmup.py | warmup_jit(E, H, I): pre-compiles all CuTe + Triton kernels |
| sonicmoe/quack_utils/blockscaled_fp8_gemm.py | FP8 GEMM wrappers (CUTLASS + Triton quant), cache key design |
| sonicmoe/quack_utils/swiglu_triton.py | Fused SwiGLU Triton kernels (5 production + 2 legacy bf16 variants) |
| sonicmoe/quack_utils/_validate.py | Low-overhead input validation (dtype/stride/shape, zero GPU sync) |
| sonicmoe/ernie_compat/deepep_metadata.py | DeepEP topk → SonicMoE routing metadata conversion |
| sonicmoe/functional/__init__.py | _UpProjection, _DownProjection autograd Functions |

Test Files

| Test | What it validates | Run command |
| --- | --- | --- |
| test_cold_start_e2e.py | Cache clear → JIT → 6-shape precision (out/dx/ds/dw1/dw2) | CUDA_VISIBLE_DEVICES=2 python tests/ops/test_cold_start_e2e.py |
| test_mlpnode_correctness_large.py | Topk kernel bug-fix regression: 9 cases incl. SEQ=16K (TK=131072), skew/extreme/holes | CUDA_VISIBLE_DEVICES=7 python tests/ops/test_mlpnode_correctness_large.py |
| test_jit_optimization.py --quick | Correctness (cos>0.99), zero JIT recompile, memory | CUDA_VISIBLE_DEVICES=0 python tests/ops/test_jit_optimization.py --quick |
| test_mlpnode_precision.py | Multi-topk precision audit | CUDA_VISIBLE_DEVICES=0 python tests/ops/test_mlpnode_precision.py |
| bench_mlpnode_mem.py | E=32 fwd+bwd memory benchmark (reference shape) | CUDA_VISIBLE_DEVICES=1 python tests/ops/bench_mlpnode_mem.py |
| bench_wgrad_epilogue.py | A/B wgrad epilogue benchmark (TMA add vs fused beta) | CUDA_VISIBLE_DEVICES=2 python tests/ops/bench_wgrad_epilogue.py |
| bench_mlpnode_topk_nsys.py | nsys GPU-projection benchmark | Wrap with nsys profile --resolve-symbols=false |

Read First (for next developer/agent)

| Priority | Resource | Path and contents |
| --- | --- | --- |
| 1 | Handoff (current state) | Root HANDOFF.md: canonical current state, lessons, insights, and next plan |
| 2 | PaddleFleet migration | docs/PADDLEFLEET_MIGRATION_S74.md: stream patch, node.step() ordering, lazy main_grad, Fleet's pre-fused-weight integration path |
| 3 | This README | Root README.md: architecture, cache design, training loop, test matrix |
| 4 | Newcomer Guide | reports/sonic_moe_fp8_frontier_newcomer_guide.md: FP8/CuTe/SonicMoE basics through expert Q&A, MFU math (Section 7) |
| 5 | GemmDGated Design Doc | docs/gemm_dgated_fp8cload.md: Int16 trick, dSwiGLU SASS decomposition, FP8 blockscaled dequant, zero-materialization |
| 6 | Expert Interleave Design Doc | docs/expert_interleave_weight_layout.md: full-stack benefit analysis of gate/up interleaved weight layout |
| 7 | Engineering Log | reports/fp8_upgrade/engineering_log.md: historical lesson log only; current-state correction block at top |
| 8 | Environment | /root/paddlejob/share-storage/gpfs/system-public/panzhaowu/env.md: machine setup, Paddle compat pitfalls, perf methodology |

Project state (clean handoff, 2026-05-09): branch race-fix-paddle. FP8 frontier remains green: precision (out/dx/dw1/dw2/ds) cos≥0.997, determinism hard-gated, route-level padding active, TMA reduce-add wgrad default, node.step() MUST precede optimizer.step(), main_grad lazy-allocated. Latest reference-shape FP8 busy time is 2659.8 µs/iter (46.51% MFU, 1.63x vs true BF16). BF16 baseline (4346 µs) now verified clean via nsys (zero FP8 kernels). Read HANDOFF.md before any kernel work; the current P0 is structural dgrad1 optimization (live-range shortening / fission), NOT direct dz-quant epilogue fusion. MXFP8 128-row alignment waste analysis completed — see /panzhaowu/bkup/mxfp8_alignment_waste_analysis.pdf.

Quick-validate the frontier before resuming work:

source .runenv.sh
bash tests/run_regression.sh                  # full regression incl. determinism gate
# OR a fast spot-check:
python -m pytest tests/fp8_frontier_determinism_test.py -v
python -m pytest tests/ops/test_mlpnode_multilayer.py tests/ops/test_mlpnode_correctness_large.py \
                 tests/ops/test_colwise_quant.py tests/ops/test_rowwise_quant.py tests/ops/test_fused_quant.py
python tests/ops/test_mlpnode_precision.py     # 6-shape topk precision audit

Native PyTorch Quick Start

import torch
from sonicmoe import MoE, SonicMoEConfig
from sonicmoe.enums import ActivationType

moe = MoE(num_experts=8, num_experts_per_tok=8, hidden_size=3072,
           intermediate_size=1536, activation_function=ActivationType.SWIGLU,
           add_bias=False, std=0.02).to(device="cuda", dtype=torch.bfloat16)

x = torch.randn(8192, 3072, device="cuda", dtype=torch.bfloat16)

cfg = SonicMoEConfig(use_fp8=True, use_quack_gemm=True)
with cfg.activate():
    output, aux_loss = moe(x, use_fp8=True)

CI & Pre-commit Hook

Strict-baseline regression runner — every core mechanism (precision, multilayer/PP, quant kernels, JIT cold/warm/reload/reuse, nsys perf, multi-card) is measured and gated against tools/ci/baselines.json:

# fast pre-commit suite (~2 min on warm cache)
bash tools/ci/run_core_tests.sh --fast

# full suite — also runs jit-cold (~10 min), perf gate (nsys), multi-card
bash tools/ci/run_core_tests.sh

Install the pre-commit hook once per clone:

git config core.hooksPath .githooks

Offline pre-warm the JIT cache (Triton + Quack disk caches + sentinel) so production training skips the multi-minute first-loss cost:

python -m sonicmoe.cli.warmup --E 32 --H 3072 --I 1536 \
    --cache-dir /nfs/sonicmoe_jit_e32_h3072_i1536
# then export SONIC_MOE_CACHE_DIR=/nfs/... on every training rank

Citation

@misc{guo2025sonicmoeacceleratingmoeio,
      title={SonicMoE: Accelerating MoE with IO and Tile-aware Optimizations},
      author={Wentao Guo and Mayank Mishra and Xinle Cheng and Ion Stoica and Tri Dao},
      year={2025},
      eprint={2512.14080},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2512.14080},
}

License

Apache License 2.0 - see LICENSE.

About

An extended sonic-moe implementation, with FP8 support fully developed by agents
