Add optimized FMHA kernel on MI355X #629

Merged
coderfeli merged 71 commits into main from opus_align_pr on Jun 4, 2026
@yanguahe yanguahe commented Jun 2, 2026

Motivation

This PR adds a gfx950-optimized FlashAttention forward fast path for head_dim=128 on MI350-series GPUs. The goal is to bring FlyDSL's FMHA performance closer to the OPUS/hand-tuned baselines while keeping the same math and layout contract as the existing generic FlashAttention path.

The new path targets both bf16 and fp16, supports causal and non-causal operation, and handles both MHA and GQA/MQA (num_kv_heads <= num_heads). It is intended to improve the common D=128 attention workload on gfx950 without changing behavior for unsupported architectures, dimensions, or shapes.

Technical Details

  • Renamed the generic implementation from kernels/flash_attn_func.py to kernels/flash_attn_generic.py, and kept it as the fallback path.
  • Added kernels/flash_attn_gfx950.py, a dual-wave, software-pipelined gfx950 implementation for D=128 bf16/fp16 FlashAttention.
  • Added dispatch logic so the gfx950 fast path is selected only for supported shapes:
    • gpu_arch >= gfx950
    • head_dim == 128
    • dtype in bf16 / fp16
    • runtime seq_len >= 384 and seq_len % 256 == 0
  • Added MHA/GQA/MQA addressing support by separating Q-head and KV-head indexing (num_heads, num_kv_heads, GQA_GROUP_SIZE).
  • Implemented an explicit 8-cluster software pipeline that alternates memory and compute stages. The kernel uses double-buffered K/V LDS tiles, async global-to-LDS DMA, ds_read_b64_tr_b16 transpose LDS reads, and explicit sched_barrier / sched_group_barrier hints to constrain instruction scheduling.
  • Added a two-wave-group stagger scheme for gfx950: waves 0-3 and waves 4-7 are offset by one barrier ordinal so one group can compute while the other group loads. The compute clusters are bracketed with s_setprio(1)/s_setprio(0) to improve handoff between the two wave groups.
  • Added OPUS-style online softmax optimizations:
    • lazy rescale that skips O *= corr when all lanes remain below the rescale threshold
    • split exp2 placement across clusters to hide transcendental latency behind MFMA chains
    • inline-asm causal mask using immediate v_cmp_lt_i32 + v_cndmask_b32 pairs
  • Added benchmark/build helper scripts under fmha_opt_tools/ and comparison support for OPUS, aiter_ck, and hand-assembly baselines.
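
The dispatch conditions listed above can be read as a single predicate. The sketch below is only an illustration of that reading, not the PR's dispatcher: the function name, the parameter names, and the way the architecture string is parsed are all assumptions.

```python
def is_gfx950_fast_path_eligible(gpu_arch: str, head_dim: int,
                                 dtype: str, seq_len: int) -> bool:
    """Hypothetical predicate mirroring the dispatch conditions listed above."""
    # Assumption: arch strings look like "gfx950" and the suffix can be
    # compared as a decimal integer. A real check may be stricter.
    if not gpu_arch.startswith("gfx"):
        return False
    try:
        arch_num = int(gpu_arch[3:])
    except ValueError:
        return False  # e.g. "gfx90a" has a non-decimal suffix
    if arch_num < 950:
        return False
    if head_dim != 128:
        return False
    if dtype not in ("bf16", "fp16"):
        return False
    # Runtime shape condition from the PR description.
    return seq_len >= 384 and seq_len % 256 == 0


print(is_gfx950_fast_path_eligible("gfx950", 128, "bf16", 512))  # eligible
print(is_gfx950_fast_path_eligible("gfx950", 128, "bf16", 384))  # falls back
```

Under the two stated seq_len conditions, S=384 itself is not eligible because 384 % 256 != 0, so the smallest eligible sequence length is 512.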

Test Plan

  • Test environment: AMD Instinct MI355X (gfx950), Linux 6.8.0-71-generic, ROCm 7.1.0, Triton 3.4.0+rocm7.1.0.gitf9e5bf54, GNU 11.4.0.
  • Docker image: rocm/pytorch:rocm7.1_ubuntu24.04_py3.12_pytorch_release_2.8.0.
  • Benchmark branch: opus_align. The kernels/flash_attn_gfx950.py implementation on that branch is identical to the code in this PR. The branch also contains benchmark-only test harness changes that are intentionally not included in this PR, but were used for fair comparisons:
    • run_opus_attn_bench(...) runs the OPUS reference kernel for supported bf16 shapes.
    • run_exp_isa_fmha_bench(...) runs the prebuilt hand-assembly FMHA kernel baseline for supported bf16 shapes.
  • Dependency commits used during benchmarking:
    • aiter: 45c428e54ac15b9b49d66018c8a1108b20c8336a
    • LLVM: 7f77ca0dbda4abbf9af06537b2c475f20ccd6007
  • Rebuilt the environment and ran the benchmark with:
    export FLYDSL_FLASH_ATTN_FUNC_USE_CUSTOM_LLVM=0
    ./fmha_opt_tools/build_env_and_run_benchmark.sh --aiter-dir ../aiter/ --llvm-dir ../llvm-project/ --flydsl-dir . -j 128
    This script also runs the benchmark and writes the result table to fmha_perf_compare_MI355X.csv.
  • Rebuilt the generated kernel artifacts used by the benchmark:
    • flash_attn_opus.v1.co
    • fmha_fwd_hd128_bf16_1tg_8w_256x64_350_msk0_gm0.co
    • fmha_fwd_hd128_bf16_1tg_8w_256x64_350_msk1_gm0.co
  • Ran causal FMHA validation/performance comparison for D=128 over both fp16 and bf16.
  • Covered 26 shapes spanning S=128..8192, H/Hkv=64/64,32/32,16/16,8/8, plus the GQA case H=64,Hkv=8; large-batch rows include B=4/8/16/32 at S=8192.
  • Compared FlyDSL against aiter_ck for fp16 and bf16; for bf16, also compared against OPUS and aiter_asm where those baselines are available.
  • Reproduction commands:
    export FLYDSL_FLASH_ATTN_FUNC_USE_CUSTOM_LLVM=0
    python tests/kernels/test_flash_attn_fwd.py --causal --dtype fp16 --iters 100 --compare
    python tests/kernels/test_flash_attn_fwd.py --causal --dtype bf16 --iters 100 --compare

Test Result

  • Benchmark CSV before rebase: fmha_perf_compare_MI355X.before_rebase.csv

  • Benchmark CSV after rebase: fmha_perf_compare_MI355X.csv

  • Build completed successfully.

  • All benchmark rows completed, with compared rows matching baseline correctness (MaxErr ratio 1.00x).

  • fp16 causal: FlyDSL avg 666.0 TFLOPS / 1943.2 us, MaxErr 4.88e-04; aiter_ck avg 530.3 TFLOPS / 2559.0 us. Average FlyDSL/aiter_ck throughput: 119.9%.

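  • bf16 causal: FlyDSL avg 699.2 TFLOPS / 1830.6 us, MaxErr 3.91e-03; OPUS avg 826.9 TFLOPS, aiter_ck avg 551.4 TFLOPS, and aiter_asm avg 781.5 TFLOPS. The OPUS and aiter_asm averages cover only the rows where those baselines ran, so they are not directly comparable with the all-rows FlyDSL average. Average per-row throughput ratios: 101.6% vs OPUS, 120.1% vs aiter_ck, 103.1% vs aiter_asm.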
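
For readers who want to relate the Time(us) and TFLOPS columns, here is a minimal sketch. The benchmark's FLOP formula is not shown in this PR; the convention below (two GEMMs, 2 FLOPs per multiply-add, halved for causal masking, counted over Q heads) is an assumption that happens to reproduce the reported values to within the rounding of the reported times. The function name is hypothetical.

```python
def causal_fmha_tflops(batch: int, seq_len: int, num_heads: int,
                       head_dim: int, time_us: float) -> float:
    """Assumed FLOP convention for causal attention forward."""
    # Q@K^T and P@V each cost 2 * S * S * D FLOPs per (batch, head);
    # causal masking is counted as half of the full square.
    flops = 4 * batch * num_heads * seq_len * seq_len * head_dim / 2
    return flops / (time_us * 1e-6) / 1e12


# fp16 row B=1, S=8192, H=64: the table reports 1023.3 us and 1074.5 TFLOPS.
print(round(causal_fmha_tflops(1, 8192, 64, 128, 1023.3), 1))
# GQA row B=16, S=8192, H=64, Hkv=8: the table reports 15064.6 us and 1167.8 TFLOPS.
print(round(causal_fmha_tflops(16, 8192, 64, 128, 15064.6), 1))
```

Note that num_kv_heads does not enter this formula: under this convention the GQA row is counted with the same FLOPs as the MHA row of the same shape.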

fp16 causal full table

   B      S    H  Hkv    D dtype   causal |            FlyDSL            |             OPUS             |           aiter_ck           |          aiter_asm           |    Fly/OPUS    |  Fly/aiter_ck  | Fly/aiter_asm 
                                          |   Time(us)   TFLOPS   MaxErr |   Time(us)   TFLOPS   MaxErr |   Time(us)   TFLOPS   MaxErr |   Time(us)   TFLOPS   MaxErr |  TFLOPS MaxErr |  TFLOPS MaxErr |  TFLOPS MaxErr
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
   8    128   64   64  128  fp16   causal |       18.3    117.7 4.88e-04 |         --       --       -- |       20.8    103.3 4.88e-04 |         --       --       -- |      --     -- |  113.9%  1.00x |      --     --
   8    256   64   64  128  fp16   causal |       37.9    226.5 4.88e-04 |         --       --       -- |       41.7    206.2 4.88e-04 |         --       --       -- |      --     -- |  109.9%  1.00x |      --     --
   8    512   64   64  128  fp16   causal |       81.7    420.7 4.88e-04 |         --       --       -- |       94.0    365.6 4.88e-04 |         --       --       -- |      --     -- |  115.1%  1.00x |      --     --
   1    128   64   64  128  fp16   causal |        9.5     28.2 4.88e-04 |         --       --       -- |        7.7     34.8 4.88e-04 |         --       --       -- |      --     -- |   81.1%  1.00x |      --     --
   1    256   64   64  128  fp16   causal |       13.5     79.7 4.88e-04 |         --       --       -- |       10.6    101.6 4.88e-04 |         --       --       -- |      --     -- |   78.4%  1.00x |      --     --
   1    384   64   64  128  fp16   causal |       17.5    138.0 4.88e-04 |         --       --       -- |       17.2    140.8 4.88e-04 |         --       --       -- |      --     -- |   98.1%  1.00x |      --     --
   1    512   64   64  128  fp16   causal |       19.5    220.2 4.88e-04 |         --       --       -- |       21.4    200.7 4.88e-04 |         --       --       -- |      --     -- |  109.7%  1.00x |      --     --
   1   1024   64   64  128  fp16   causal |       32.6    526.3 4.88e-04 |         --       --       -- |       40.7    422.6 4.88e-04 |         --       --       -- |      --     -- |  124.5%  1.00x |      --     --
   1   2048   64   64  128  fp16   causal |       93.1    738.3 4.88e-04 |         --       --       -- |      114.5    600.1 4.88e-04 |         --       --       -- |      --     -- |  123.0%  1.00x |      --     --
   1   4096   64   64  128  fp16   causal |      291.3    943.8 4.88e-04 |         --       --       -- |      351.4    782.3 4.88e-04 |         --       --       -- |      --     -- |  120.6%  1.00x |      --     --
   1   8192   64   64  128  fp16   causal |     1023.3   1074.5 4.88e-04 |         --       --       -- |     1265.5    868.9 4.88e-04 |         --       --       -- |      --     -- |  123.7%  1.00x |      --     --
   4   8192   64   64  128  fp16   causal |     4021.5   1093.6 4.88e-04 |         --       --       -- |     5184.0    848.4 4.88e-04 |         --       --       -- |      --     -- |  128.9%  1.00x |      --     --
   1   2048   32   32  128  fp16   causal |       52.1    659.9 4.88e-04 |         --       --       -- |       67.2    511.5 4.88e-04 |         --       --       -- |      --     -- |  129.0%  1.00x |      --     --
   1   4096   32   32  128  fp16   causal |      156.6    877.6 4.88e-04 |         --       --       -- |      192.4    714.4 4.88e-04 |         --       --       -- |      --     -- |  122.8%  1.00x |      --     --
   1   8192   32   32  128  fp16   causal |      532.0   1033.4 4.88e-04 |         --       --       -- |      637.0    863.1 4.88e-04 |         --       --       -- |      --     -- |  119.7%  1.00x |      --     --
   8   8192   32   32  128  fp16   causal |     3996.6   1100.4 4.88e-04 |         --       --       -- |     5137.8    856.0 4.88e-04 |         --       --       -- |      --     -- |  128.6%  1.00x |      --     --
   1   2048   16   16  128  fp16   causal |       43.9    391.6 4.88e-04 |         --       --       -- |       57.7    297.9 4.88e-04 |         --       --       -- |      --     -- |  131.5%  1.00x |      --     --
   1   4096   16   16  128  fp16   causal |       91.3    752.9 4.88e-04 |         --       --       -- |      118.2    581.6 4.88e-04 |         --       --       -- |      --     -- |  129.5%  1.00x |      --     --
   1   8192   16   16  128  fp16   causal |      286.9    958.1 4.88e-04 |         --       --       -- |      350.7    783.9 4.88e-04 |         --       --       -- |      --     -- |  122.2%  1.00x |      --     --
  16   8192   16   16  128  fp16   causal |     3893.4   1129.6 4.88e-04 |         --       --       -- |     5107.3    861.1 4.88e-04 |         --       --       -- |      --     -- |  131.2%  1.00x |      --     --
   1   2048    8    8  128  fp16   causal |       42.6    201.6 4.88e-04 |         --       --       -- |       47.1    182.3 4.88e-04 |         --       --       -- |      --     -- |  110.6%  1.00x |      --     --
   1   4096    8    8  128  fp16   causal |       78.4    438.5 4.88e-04 |         --       --       -- |      105.3    326.4 4.88e-04 |         --       --       -- |      --     -- |  134.4%  1.00x |      --     --
   1   8192    8    8  128  fp16   causal |      169.4    811.3 4.88e-04 |         --       --       -- |      220.6    622.9 4.88e-04 |         --       --       -- |      --     -- |  130.2%  1.00x |      --     --
  32   8192    8    8  128  fp16   causal |     3921.3   1121.6 4.88e-04 |         --       --       -- |     5239.3    839.4 4.88e-04 |         --       --       -- |      --     -- |  133.6%  1.00x |      --     --
  16   8192   64   64  128  fp16   causal |    16535.0   1063.9 4.88e-04 |         --       --       -- |    21207.1    829.5 4.88e-04 |         --       --       -- |      --     -- |  128.3%  1.00x |      --     --
  16   8192   64    8  128  fp16   causal |    15064.6   1167.8 4.88e-04 |         --       --       -- |    20877.7    842.6 4.88e-04 |         --       --       -- |      --     -- |  138.6%  1.00x |      --     --
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
                                AVG (all) |     1943.2    666.0 4.88e-04 |         --       --       -- |     2559.0    530.3 4.88e-04 |         --       --       -- |      --     -- |  119.9%  1.00x |      --     --
========================================================================================================================================================================================================================

bf16 causal full table

   B      S    H  Hkv    D dtype   causal |            FlyDSL            |             OPUS             |           aiter_ck           |          aiter_asm           |    Fly/OPUS    |  Fly/aiter_ck  | Fly/aiter_asm 
                                          |   Time(us)   TFLOPS   MaxErr |   Time(us)   TFLOPS   MaxErr |   Time(us)   TFLOPS   MaxErr |   Time(us)   TFLOPS   MaxErr |  TFLOPS MaxErr |  TFLOPS MaxErr |  TFLOPS MaxErr
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
   8    128   64   64  128  bf16   causal |       18.6    115.6 3.91e-03 |         --       --       -- |       20.8    103.0 3.91e-03 |         --       --       -- |      --     -- |  112.1%  1.00x |      --     --
   8    256   64   64  128  bf16   causal |       38.4    223.4 3.91e-03 |         --       --       -- |       41.7    206.2 3.91e-03 |       35.1    244.6 3.91e-03 |      --     -- |  108.4%  1.00x |   91.3%  1.00x
   8    512   64   64  128  bf16   causal |       80.1    429.1 3.91e-03 |       81.7    420.6 3.91e-03 |       91.7    374.6 3.91e-03 |       80.4    427.5 3.91e-03 |  102.0%  1.00x |  114.6%  1.00x |  100.4%  1.00x
   1    128   64   64  128  bf16   causal |       10.1     26.5 3.91e-03 |         --       --       -- |        8.3     32.2 3.91e-03 |         --       --       -- |      --     -- |   82.4%  1.00x |      --     --
   1    256   64   64  128  bf16   causal |       14.2     75.8 3.91e-03 |         --       --       -- |       10.8     99.3 3.91e-03 |       12.2     88.1 3.91e-03 |      --     -- |   76.4%  1.00x |   86.1%  1.00x
   1    384   64   64  128  bf16   causal |       18.2    132.5 3.91e-03 |         --       --       -- |       16.9    142.6 3.91e-03 |         --       --       -- |      --     -- |   92.9%  1.00x |      --     --
   1    512   64   64  128  bf16   causal |       20.0    214.6 3.91e-03 |       19.5    220.1 3.91e-03 |       21.2    202.2 3.91e-03 |       26.9    159.9 3.91e-03 |   97.5%  1.00x |  106.2%  1.00x |  134.3%  1.00x
   1   1024   64   64  128  bf16   causal |       32.1    534.6 3.91e-03 |       33.3    516.4 3.91e-03 |       39.0    440.5 3.91e-03 |       39.9    430.9 3.91e-03 |  103.5%  1.00x |  121.3%  1.00x |  124.1%  1.00x
   1   2048   64   64  128  bf16   causal |       88.1    779.8 3.91e-03 |       88.9    772.6 3.91e-03 |      107.4    639.6 3.91e-03 |       82.4    833.6 3.91e-03 |  100.9%  1.00x |  121.9%  1.00x |   93.5%  1.00x
   1   4096   64   64  128  bf16   causal |      273.0   1006.8 3.91e-03 |      273.6   1004.8 3.91e-03 |      338.0    813.2 3.91e-03 |      259.0   1061.3 3.91e-03 |  100.2%  1.00x |  123.8%  1.00x |   94.9%  1.00x
   1   8192   64   64  128  bf16   causal |      957.6   1148.2 3.91e-03 |      965.2   1139.1 3.91e-03 |     1227.2    895.9 3.91e-03 |      925.0   1188.6 3.91e-03 |  100.8%  1.00x |  128.2%  1.00x |   96.6%  1.00x
   4   8192   64   64  128  bf16   causal |     3785.5   1161.8 3.91e-03 |     3804.3   1156.1 3.91e-03 |     4981.1    882.9 3.91e-03 |     3663.9   1200.4 3.91e-03 |  100.5%  1.00x |  131.6%  1.00x |   96.8%  1.00x
   1   2048   32   32  128  bf16   causal |       50.5    680.3 3.91e-03 |       53.8    639.2 3.91e-03 |       64.7    531.1 3.91e-03 |       60.4    569.1 3.91e-03 |  106.4%  1.00x |  128.1%  1.00x |  119.5%  1.00x
   1   4096   32   32  128  bf16   causal |      148.1    927.9 3.91e-03 |      152.1    903.7 3.91e-03 |      185.1    742.6 3.91e-03 |      133.3   1030.9 3.91e-03 |  102.7%  1.00x |  124.9%  1.00x |   90.0%  1.00x
   1   8192   32   32  128  bf16   causal |      498.4   1103.0 3.91e-03 |      496.8   1106.6 3.91e-03 |      612.0    898.2 3.91e-03 |      467.1   1177.0 3.91e-03 |   99.7%  1.00x |  122.8%  1.00x |   93.7%  1.00x
   8   8192   32   32  128  bf16   causal |     3776.1   1164.7 3.91e-03 |     3799.7   1157.5 3.91e-03 |     4897.9    897.9 3.91e-03 |     3662.6   1200.8 3.91e-03 |  100.6%  1.00x |  129.7%  1.00x |   97.0%  1.00x
   1   2048   16   16  128  bf16   causal |       43.7    392.9 3.91e-03 |       45.4    378.3 3.91e-03 |       56.6    303.6 3.91e-03 |       52.7    325.7 3.91e-03 |  103.9%  1.00x |  129.4%  1.00x |  120.6%  1.00x
   1   4096   16   16  128  bf16   causal |       88.1    780.2 3.91e-03 |       90.2    761.6 3.91e-03 |      113.2    607.3 3.91e-03 |       97.6    704.5 3.91e-03 |  102.4%  1.00x |  128.5%  1.00x |  110.8%  1.00x
   1   8192   16   16  128  bf16   causal |      268.4   1024.2 3.91e-03 |      275.5    997.7 3.91e-03 |      331.9    828.3 3.91e-03 |      237.4   1158.1 3.91e-03 |  102.6%  1.00x |  123.6%  1.00x |   88.4%  1.00x
  16   8192   16   16  128  bf16   causal |     3672.3   1197.6 3.91e-03 |     3706.3   1186.7 3.91e-03 |     4897.2    898.1 3.91e-03 |     3661.4   1201.2 3.91e-03 |  100.9%  1.00x |  133.4%  1.00x |   99.7%  1.00x
   1   2048    8    8  128  bf16   causal |       42.9    200.3 3.91e-03 |       43.0    199.7 3.91e-03 |       47.2    181.9 3.91e-03 |       50.6    169.7 3.91e-03 |  100.3%  1.00x |  110.1%  1.00x |  118.0%  1.00x
   1   4096    8    8  128  bf16   causal |       75.7    453.8 3.91e-03 |       78.8    436.2 3.91e-03 |      103.3    332.7 3.91e-03 |       87.1    394.5 3.91e-03 |  104.0%  1.00x |  136.4%  1.00x |  115.0%  1.00x
   1   8192    8    8  128  bf16   causal |      162.2    847.3 3.91e-03 |      164.5    835.5 3.91e-03 |      212.4    647.2 3.91e-03 |      167.4    821.2 3.91e-03 |  101.4%  1.00x |  130.9%  1.00x |  103.2%  1.00x
  32   8192    8    8  128  bf16   causal |     3702.0   1188.0 3.91e-03 |     3751.1   1172.5 3.91e-03 |     5011.7    877.5 3.91e-03 |     3667.5   1199.2 3.91e-03 |  101.3%  1.00x |  135.4%  1.00x |   99.1%  1.00x
  16   8192   64   64  128  bf16   causal |    15509.3   1134.3 3.91e-03 |    15506.8   1134.5 3.91e-03 |    20214.6    870.3 3.91e-03 |    14850.7   1184.6 3.91e-03 |  100.0%  1.00x |  130.3%  1.00x |   95.8%  1.00x
  16   8192   64    8  128  bf16   causal |    14221.7   1237.0 3.91e-03 |    14363.9   1224.7 3.91e-03 |    19818.4    887.7 3.91e-03 |    14622.1   1203.1 3.91e-03 |  101.0%  1.00x |  139.4%  1.00x |  102.8%  1.00x
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
                                AVG (all) |     1830.6    699.2 3.91e-03 |     2275.9    826.9 3.91e-03 |     2441.2    551.4 3.91e-03 |     2041.0    781.5 3.91e-03 |  101.6%  1.00x |  120.1%  1.00x |  103.1%  1.00x
========================================================================================================================================================================================================================

Submission Checklist

yanguahe and others added 30 commits May 8, 2026 03:27
- kernels/flash_attn_func.py: add num_kv_heads (default = num_heads = MHA);
  split STRIDE_TOKEN/global_idx/head_idx into Q (used by Q,O) and KV (used
  by K,V coop_load and DMA); recursive auto-launch threads num_kv_heads.
- tests: --num_kv_heads CLI flag; DEFAULT_CONFIGS extended to 5-tuple
  (batch, seq_len, num_heads, num_kv_heads, head_dim); add GQA-8 row
  (16, 8192, 64, 8, 128); pytorch reference expands K/V via
  repeat_interleave; table/CSV layout adds Hkv column.
- tests: fix aiter.mha_fwd / aiter.fmha_v3_fwd argument-count mismatch (now
  uses keyword args for trailing optionals).
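
The Q-head to KV-head split described above can be illustrated in a few lines. This is a sketch under the assumption that consecutive Q heads share one KV head, which is the mapping implied by a repeat_interleave reference; the helper names are hypothetical and the list-based repeat_interleave only imitates the tensor operation.

```python
def kv_head_for_q_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a Q head index to the KV head it reads (assumed grouping)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # Q heads per KV head
    return q_head // group_size


def repeat_interleave(items, repeats):
    """List analogue of repeat_interleave: each item repeated consecutively."""
    return [item for item in items for _ in range(repeats)]


num_heads, num_kv_heads = 64, 8
kv_heads = [f"kv{i}" for i in range(num_kv_heads)]
expanded = repeat_interleave(kv_heads, num_heads // num_kv_heads)

# The kernel-side index and the expanded reference agree for every Q head.
for q in range(num_heads):
    assert expanded[q] == kv_heads[kv_head_for_q_head(q, num_heads, num_kv_heads)]
```

With num_kv_heads == num_heads the group size is 1 and the mapping is the identity, which matches the MHA default mentioned in the commit.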

Verified on MI355X (gfx950) bf16 B=16 S=8192 D=128 iters=100 after clean
FlyDSL rebuild:
  MHA   (Hkv=64): causal 692.3 / nocausal 636.8 TFLOPS,
                  MaxErr 3.91e-03 / 2.44e-04 (PASS)
  GQA-8 (Hkv=8) : causal 687.1 / nocausal 701.7 TFLOPS,
                  MaxErr 3.91e-03 / 2.44e-04 (PASS)

Co-authored-by: Cursor <cursoragent@cursor.com>

- Change log().info(...) to log().debug(...) for the per-rewriter AST
  diff and the final transformed code dump in ast_rewriter.py.
- These messages are noisy compile-time diagnostics; users only need
  them when actively debugging the AST rewrite. They now require
  FLYDSL_DEBUG_LOG_LEVEL=DEBUG (with FLYDSL_DEBUG_LOG_TO_CONSOLE=1
  or FLYDSL_DEBUG_LOG_TO_FILE=...) to surface, instead of the more
  commonly enabled INFO level.

Co-authored-by: Cursor <cursoragent@cursor.com>

- Add opus_attn C++ templates (d128/d512 causal/noncausal), host/driver, and setuptools build
- Add opus_attn Python wrapper, compare/rebuild/install scripts, and README
- Add run.sh launcher and test_flash_opus_attn kernel tests
- Extend test_flash_attn_func DEFAULT_CONFIGS with GQA kv_heads=num_heads line

Verified: not benchmarked in this commit (functional tests available under tests/kernels/)
Co-authored-by: Cursor <cursoragent@cursor.com>

- new kernels/flash_attn_opus.py: D=128 bf16 fast path for gfx950+,
  modeled after opus_attn/gqa_d128_kernel_template.hpp.
  Key OPUS optimizations included:
  * 3D grid launch (H, num_q_blocks, B) for better workload distribution
  * Double-buffered K LDS with buffer_load_dwordx4_lds DMA
  * ds_read_tr16_b64 HW-transpose V reads
  * Online softmax with lazy rescale (ballot+read_exec): clamps
    row_max := m_running when (row_max - m_row) <= 8.0 across all lanes,
    skipping O *= corr (corr == 1)
  * s_setprio(1)/s_setprio(0) brackets around GEMM2/rescale cluster
  * s_nop 15 + s_nop 7 yield window after s_setprio(0)
  * Causal mask via per-element v_cmp_lt + v_cndmask (select chain)
- kernels/flash_attn_func.py: add OPUS dispatcher
  * built only when head_dim=128, dtype=bf16, gfx950+
  * runtime dispatch when seq_len >= 384 and seq_len % 256 == 0
  * gated by FLYDSL_ENABLE_OPUS_PATH=1 (opt-in until perf matches baseline)
  * non-eligible runtime shapes and configs fall through unchanged
- env-var knobs in flash_attn_opus.py:
  FLYDSL_OPUS_LAZY_RESCALE / FLYDSL_OPUS_SETPRIO / FLYDSL_OPUS_YIELD_NOP
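
The lazy-rescale idea above can be checked numerically on a single row. The following is a scalar sketch, not the kernel's code: it assumes scores are already in log2 space, uses a scalar value per key instead of a V vector, and makes the skip decision per row rather than across all lanes of a wave. The function names are hypothetical; the 8.0 threshold is the one quoted in the commit message.

```python
import math


def online_softmax_lazy(scores, values, tile=4, threshold=8.0):
    """Base-2 online softmax over one row with lazy rescaling.

    Returns the softmax-weighted sum of `values` and the number of tiles
    that skipped the rescale step.
    """
    m = -math.inf  # running max used as the exp2 reference point
    l = 0.0        # running denominator
    o = 0.0        # running unnormalized output
    skipped = 0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        tile_max = max(s_tile)
        if m != -math.inf and tile_max - m <= threshold:
            # Lazy path: keep the old reference max, so corr == 1 and the
            # O *= corr multiply is skipped. exp2 arguments stay <= threshold.
            skipped += 1
        else:
            m_new = max(m, tile_max)
            corr = 0.0 if m == -math.inf else 2.0 ** (m - m_new)
            l *= corr
            o *= corr
            m = m_new
        for s, v in zip(s_tile, v_tile):
            p = 2.0 ** (s - m)
            l += p
            o += p * v
    return o / l, skipped


def softmax_weighted_sum_ref(scores, values):
    """Reference: two-pass base-2 softmax-weighted sum."""
    top = max(scores)
    weights = [2.0 ** (s - top) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


scores = [0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 9.0, 6.0, 30.0, 2.0, 1.0, 0.5]
values = [float(i) for i in range(len(scores))]
lazy_out, skipped = online_softmax_lazy(scores, values)
ref_out = softmax_weighted_sum_ref(scores, values)
assert abs(lazy_out - ref_out) < 1e-9
```

The result is unchanged because the final O / l ratio does not depend on which reference max was used, as long as numerator and denominator use the same one; the threshold only bounds how large the intermediate exp2 values may grow.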

Verified correctness (MaxErr threshold 8e-03):
  B=1 S=512    causal+nocausal       MaxErr=3.91e-03 / 4.88e-04  PASS
  B=1 S=8192   causal+nocausal       MaxErr=3.91e-03 / 2.44e-04  PASS
  B=4 S=2048   causal                MaxErr=3.91e-03             PASS

Performance B=16 S=8192 H=64 D=128 bf16 (MI355X), in TFLOPS:
                 default-path  OPUS-path  OPUS-C++  ASM
  causal          716           636        1131      595
  nocausal        640           678        1165     1249

OPUS path is currently below baseline for causal but a small win for
nocausal; ships as opt-in to avoid causal regression. Reaching the
~1074 TFLOPS target (95% of OPUS C++) requires porting OPUS's full
8-cluster pipeline, in-flight Q scaling, and exact sched_barrier_pairs
fence ordering, which remains as future work.

Co-authored-by: Cursor <cursoragent@cursor.com>

Restructure flash_attn_opus.py to match the OPUS C++ kernel layout
line-by-line at the cluster level. This is the structural foundation
on top of which subsequent perf phases (P2-P5) will land.

Structural changes (every code section labelled with its C++ line range):

- Prologue (C++ 397-436): K[0]/K[1]/V[0] async ladder + mma0 of tile 0
  + causal mask + first-half exp2 + K[2] async kickoff. Establishes the
  loop invariant v_s_0_partial = (exp2(s_lo - m_row), s_hi - m_row).

- Main loop (C++ 439-561): rewritten with `j += 2`, processing 2 KV
  tiles per iteration across 8 clusters with V double-buffer:
    * Cluster 0: V[j-2] async + K[j-2] ds_read
    * Cluster 1: GEMM0 S[j-2] + finish softmax of carried v_s[0]
    * Cluster 2: K[j] async
    * Cluster 3: GEMM2 P[j-3]@v[j-3] via step_k(0..3) +
                 lazy-rescale check (ballot/all_below) +
                 sub_row + first-half exp2(v_s[1])
    * Cluster 4: V[j-1] async + K[j-1] ds_read
    * Cluster 5: GEMM0 S[j-1] + finish softmax of v_s[1]
    * Cluster 6: K[j+1] async + causal mask on v_s[0] (= S[j-1])
    * Cluster 7: GEMM2 P[j-2]@v[j-2] via step_k(0..3) +
                 lazy-rescale + sub_row + first-half exp2(v_s[0])

- Epilogue (C++ 565-742): new 13-cluster drainer for the last 3 KV
  tiles plus the carried partial v_s[0]. Each tile uses FULL mma1
  (no lazy rescale in the epilogue, matching C++). Cluster 11 does
  sub_row + first-half exp + sched_barrier + second-half exp + cast +
  scale o, matching C++ exactly. Cluster 13 emits the final mma1.

- max_num_tiles computed from ceil_div(N, KV_TILE_SIZE) with causal
  cap (C++ 383-390) — replaces previous kv_upper token-based bound.

- Loop state extended to carry v_s_0_lo_partial / v_s_0_hi_partial
  (partial exp2 fragment) between iterations.

- s_waitcnt encodings refined per C++ template:
    * 0xC07F        : lgkmcnt(0)
    * vmcnt(k+v=4)  : main loop and epilogue clusters 0/2/4
    * vmcnt(v=2)    : epilogue clusters 6/8 (only V outstanding)
    * vmcnt(0)      : epilogue cluster 10 (drain all VMEM)

- Each cluster boundary uses the C++ triple:
  sched_barrier(0) ; s_barrier ; sched_barrier(0).
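
The 0xC07F value listed above can be decoded by hand. The sketch below assumes gfx950 uses the gfx9 s_waitcnt immediate layout (vmcnt in bits [3:0] plus [15:14], expcnt in bits [6:4], lgkmcnt in bits [11:8]); the helper names are hypothetical, and leaving unused counters at their maximum is an assumption about how the other encodings are built.

```python
def decode_s_waitcnt_gfx9(simm16: int) -> dict:
    """Split a gfx9-style s_waitcnt immediate into its three counters."""
    vmcnt = (simm16 & 0xF) | (((simm16 >> 14) & 0x3) << 4)
    expcnt = (simm16 >> 4) & 0x7
    lgkmcnt = (simm16 >> 8) & 0xF
    return {"vmcnt": vmcnt, "expcnt": expcnt, "lgkmcnt": lgkmcnt}


def encode_s_waitcnt_gfx9(vmcnt: int = 63, expcnt: int = 7,
                          lgkmcnt: int = 15) -> int:
    """Build the immediate; defaults are the maximum (no wait) values."""
    return ((vmcnt & 0xF)
            | ((expcnt & 0x7) << 4)
            | ((lgkmcnt & 0xF) << 8)
            | (((vmcnt >> 4) & 0x3) << 14))


print(decode_s_waitcnt_gfx9(0xC07F))            # only lgkmcnt is zero
print(hex(encode_s_waitcnt_gfx9(lgkmcnt=0)))    # 0xc07f
print(hex(encode_s_waitcnt_gfx9(vmcnt=4)))      # wait until <= 4 VMEM ops remain
```

In 0xC07F both vmcnt and expcnt sit at their maximum, so the instruction waits only for LDS/scalar-memory traffic, which is consistent with the lgkmcnt(0) label in the list.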

Deferred to later phases (NOT in P1):
  P2: in-flight Q scaling (drop per-FMA c_sm_scale_log2e multiplies)
  P3: sched_group_barrier_pairs / _exp_pairs scheduler discipline
  P4: inline-asm attn_mask_vec2_imm + 8 register anchors
  P5: stagger mechanism (warp_id // 4 asymmetric barriers)

Verified: MaxErr 3.91e-03 (causal) / 9.77e-04 (nocausal) at seq=512,
          MaxErr 3.91e-03 (causal) / 2.44e-04 (nocausal) at seq=8192,
          all below 8e-03 threshold. Perf intentionally NOT optimized
          in P1; current numbers will be addressed by P2-P5.
Co-authored-by: Cursor <cursoragent@cursor.com>

Align with gqa_d128_kernel_template.hpp lines 404-406: pre-multiply Q
by temperature_scale = (1/sqrt(D)) * log2(e) during the prologue load,
so that subsequent softmax math operates directly in log2 space and
the per-FMA "* sm_scale_log2e" multiplications disappear.

Changes
- Prologue Q load (lines 392-433): each bf16 MFMA pack is extended to
  f32x8, multiplied element-wise by c_sm_scale_log2e, then truncf'd
  back to bf16x8 before feeding GEMM0. Constants block is moved above
  the Q load so c_sm_scale_log2e is in scope.
- _sub_row_first_half_exp (lines 545-559): the per-element FMA chain
  fma(s, c_sm_scale_log2e, -m_row*c_sm_scale_log2e) collapses to a
  plain subf(s, m_row) since both operands are already log2-scaled.
- Main loop clusters 3 and 7: lazy-rescale check uses m_diff =
  m_tile_max - m_row (no scaling) and corr = exp2(m_row - m_new).
- Epilogue clusters 3, 7, 11: rescale_eX = exp2(m_row - row_max_eX)
  without the extra c_sm_scale_log2e multiply.

Numerics
- S becomes (sm_scale*log2e) * S_old throughout the pipeline, so
  every (row_max - m_row), corr, rescale and exp2 argument is rescaled
  by the same constant factor; final P and O values are mathematically
  identical to P1. MaxErr unchanged at every test config.
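
The numerics argument above can be checked with a small float64 sketch. It compares a natural-exponent softmax with sm_scale applied to the logits against an exp2 softmax on a pre-scaled Q. The function names are hypothetical, and the sketch deliberately omits the bf16 rounding that the kernel applies to the scaled Q.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax_scaled_logits(q, keys):
    """Reference: softmax(sm_scale * q.k) using the natural exponent."""
    sm_scale = 1.0 / math.sqrt(len(q))
    logits = [sm_scale * dot(q, k) for k in keys]
    top = max(logits)
    weights = [math.exp(x - top) for x in logits]
    total = sum(weights)
    return [w / total for w in weights]


def softmax_prescaled_q(q, keys):
    """Pre-multiply Q by (1/sqrt(D)) * log2(e), then work purely in exp2."""
    temperature_scale = (1.0 / math.sqrt(len(q))) * math.log2(math.e)
    q_scaled = [x * temperature_scale for x in q]
    logits = [dot(q_scaled, k) for k in keys]   # already in log2 space
    top = max(logits)
    weights = [2.0 ** (x - top) for x in logits]  # plain subtract, no FMA
    total = sum(weights)
    return [w / total for w in weights]


q = [0.5, -1.0, 2.0, 0.25]
keys = [[1.0, 0.0, 0.5, -0.5], [0.2, 0.3, -1.0, 2.0], [-0.7, 1.5, 0.9, 0.1]]
p_ref = softmax_scaled_logits(q, keys)
p_new = softmax_prescaled_q(q, keys)
assert all(abs(a - b) < 1e-12 for a, b in zip(p_ref, p_new))
```

The identity used is exp(x) = 2^(x * log2(e)): scaling Q once moves the constant out of every per-element softmax operation.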

Verified
- seq_len=512   causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192  causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)

Per the user's phased plan, only logic alignment + correctness are
checked at this phase; performance optimizations (sched_group_barrier
pairs, inline-asm causal mask, stagger) are deferred to P3-P5.

Co-authored-by: Cursor <cursoragent@cursor.com>

Replace generic sched_barrier(0) fences inside the compute clusters
with the same sched_group_barrier (MFMA/VALU/EXP) groups used by the
OPUS C++ template (gqa_d128_kernel_template.hpp lines 14-30, 455-720).
This gives the LLVM AMDGPU scheduler explicit hints about expected
instruction densities in every pipeline stage so it can reproduce the
intended MFMA-VALU-EXP interleaving.

Changes
- Define module-local mask constants matching LLVM AMDGPU semantics:
  MFMA=0x008, VALU=0x002, EXP=0x400.
- Add two Python helpers that mirror the C++ recursive templates:
  _sched_barrier_pairs(pairs, valu_cnt, group)         (C++ lines 18-23)
  _sched_barrier_exp_pairs(pairs, exp_cnt, group)      (C++ lines 25-30)
- Insert hint pairs in every loop/epilogue cluster at the same call
  sites as the C++ template:
    Main loop  Cluster 1 → group 1   (exp_pairs<6,3,1>; pairs<10,5,1>)
               Cluster 3 → group 2   (pairs<4,5,2>; pairs<6,5,2>; exp_pairs<6,3,2>)
               Cluster 5 → group 3   (exp_pairs<6,3,3>; pairs<10,5,3>)
               Cluster 7 → group 4   (pairs<4,5,4>; pairs<6,5,4>; exp_pairs<6,3,4>)
    Epilogue   Cluster 1 → group 5   (exp_pairs<6,3,5>; pairs<10,5,5>)
               Cluster 3 → group 6   (pairs<10,5,6>; exp_pairs<6,3,6>)
               Cluster 5 → group 7   (exp_pairs<6,3,7>; pairs<10,5,7>)
               Cluster 7 → group 8   (pairs<10,5,8>; exp_pairs<6,3,8>)
               Cluster 9 → group 9   (exp_pairs<6,3,9>; pairs<10,5,9>)
               Cluster 11→ group 10  (pairs<10,5,10>; exp_pairs<6,3,10>)
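
As a rough picture of what the two helpers expand to, here is a sketch that records the hint sequence instead of emitting it. The bodies of the C++ templates are not shown in this PR, so the pattern used here (one MFMA slot followed by N VALU or EXP slots, repeated `pairs` times) is an assumption, and `emit` is a hypothetical stand-in for a sched_group_barrier(mask, size, sync_id) call. The mask values are the ones listed in the commit message.

```python
MFMA = 0x008
VALU = 0x002
EXP = 0x400


def sched_barrier_pairs(emit, pairs: int, valu_cnt: int, group: int) -> None:
    """Assumed expansion: `pairs` x (1 MFMA, then `valu_cnt` VALU)."""
    for _ in range(pairs):
        emit(MFMA, 1, group)
        emit(VALU, valu_cnt, group)


def sched_barrier_exp_pairs(emit, pairs: int, exp_cnt: int, group: int) -> None:
    """Assumed expansion: `pairs` x (1 MFMA, then `exp_cnt` transcendentals)."""
    for _ in range(pairs):
        emit(MFMA, 1, group)
        emit(EXP, exp_cnt, group)


# Main loop Cluster 1 from the list above: exp_pairs<6,3,1>; pairs<10,5,1>.
hints = []
record = lambda mask, size, group: hints.append((mask, size, group))
sched_barrier_exp_pairs(record, 6, 3, 1)
sched_barrier_pairs(record, 10, 5, 1)

print(len(hints))  # 32 hint calls for this cluster under the assumed pattern
print(hints[:2])
```

Under this reading, Cluster 1 asks the scheduler for 16 MFMAs interleaved with 18 transcendental and 50 VALU instructions.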

Behavior
- sched_group_barrier emits no instructions of its own; it constrains
  pre-/post-RA scheduling. Numerics are identical to P2: MaxErr is
  bit-for-bit the same across every test config.
- Performance currently drops because the hints require matching
  instruction densities that the FlyDSL backend does not yet produce
  (the C++ kernel relies on hand-tuned register anchors and the
  stagger mechanism added in P4/P5 to fully exploit them). The phased
  plan explicitly defers performance to later phases.

Verified
- seq_len=512   causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192  causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)

Co-authored-by: Cursor <cursoragent@cursor.com>

Align with OPUS C++ template's hand-coded inline assembly for two
classes of low-level constructs (gqa_d128_kernel_template.hpp lines
233-249 and the eight v_s/v_p anchor sites scattered across lines
430-635).

Changes
- Add `_attn_mask_imm_single` helper that emits the
  `v_cmp_lt_i32_e64 + v_cndmask_b32_e64` pair with the threshold baked
  into the asm string as an immediate literal. `_attn_mask_vec2_imm`
  invokes it twice per (thr_x, thr_y) pair, matching the C++ semantics.
  Split into two single-asm calls rather than a 4-output struct return
  because MLIR's llvm.inline_asm with two simultaneous "=s" sgpr-pair
  outputs proved brittle.
- Rewrite `_causal_mask_inplace` to mirror C++ attn_mask_causal_tile:
  compute rel = q_pos - k_pos with k_pos = kv_start + i_n*W_N +
  lane_group*c_pack, then iterate the 8 (thr_x, thr_y) immediate pairs
  derived from the C++ static_for nest. Masks s_lo (i_n=0) and s_hi
  (i_n=1) with the same threshold list but with rel_hi = rel_lo - W_N.
- Add `_anchor_vec`, `_anchor_pair`, `_anchor_packs` helpers that emit
  `asm volatile("" : "+v"(v))`-style fences using LLVM inline asm with
  a tied "=v,0" constraint and has_side_effects=True.
- Place the 8 register anchors at the C++-matching sites:
  #1 Prologue (v_s[0])           — C++ line 430
  #2 Main Cluster 1 (v_p)        — C++ line 454
  #3 Main Cluster 3 (v_s[1])     — C++ line 489
  #4 Main Cluster 5 (v_p)        — C++ line 512
  #5 Main Cluster 7 (v_s[0])     — C++ line 553
  #6 Epi  Cluster 1 (v_p)        — C++ line 578
  #7 Epi  Cluster 3 (v_s[1])     — C++ line 607
  #8 Epi  Cluster 5 (v_p)        — C++ line 635

Behavior
- Numerics identical to P3: per-element causal-mask decision is
  q_pos < absolute_K_col, rewritten as rel < threshold. The anchors
  emit no real instructions, so the data values are unchanged.
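
The rewrite of the mask decision from absolute to relative positions can be sanity-checked with a short sketch. It assumes the thresholds are the column offsets 0..n-1 within a contiguous run of key columns; the actual eight (thr_x, thr_y) immediates derived from the C++ static_for nest are not shown in this PR, and the function names are hypothetical.

```python
NEG_INF = float("-inf")


def mask_absolute(scores, q_pos, k_pos):
    """Mask element t when the query row is before key column k_pos + t."""
    return [NEG_INF if q_pos < k_pos + t else s
            for t, s in enumerate(scores)]


def mask_relative(scores, q_pos, k_pos):
    """Same decision with one subtraction, then compares against immediates."""
    rel = q_pos - k_pos  # computed once per lane
    return [NEG_INF if rel < t else s   # t plays the role of the immediate
            for t, s in enumerate(scores)]


scores = [1.0, 2.0, 3.0, 4.0]
for q_pos in range(0, 12):
    for k_pos in range(0, 12, 4):
        assert mask_absolute(scores, q_pos, k_pos) == mask_relative(scores, q_pos, k_pos)

print(mask_relative(scores, q_pos=5, k_pos=4))  # columns 6 and 7 are masked
```

Because q_pos < k_pos + t is equivalent to q_pos - k_pos < t, the per-element thresholds become compile-time constants, which is what allows them to be baked into the asm string as immediates.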

Verified
- seq_len=512   causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192  causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)

Performance is still well below the C++ target; the phased plan keeps
that for the remaining stages (P5 stagger and any follow-on tuning).

Co-authored-by: Cursor <cursoragent@cursor.com>
Mirror the OPUS C++ template's dual-wave-group phase-shift scheme
(gqa_d128_kernel_template.hpp lines 308, 415-418, 748-750):

    const int warp_id = __builtin_amdgcn_readfirstlane(
                            thread_id_x() / WARP_SIZE);
    const int stagger = warp_id / 4;
    ...
    if (stagger) {
        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_s_barrier();
    }
    ...
    if (!stagger) {
        __builtin_amdgcn_s_barrier();
    }

Changes
- Add `OPUS_ENABLE_STAGGER` env-gated flag (FLYDSL_OPUS_STAGGER, default 0).
- Compute a SCALAR (SGPR-resident) stagger value:
    wave_id_uni = readfirstlane(tid / WARP_SIZE)
    stagger     = wave_id_uni / 4
  using `rocdl.readfirstlane` + `arith.divsi`. The result feeds two
  `arith.cmpi` results: `stagger_is_one_i1` and `stagger_is_zero_i1`.
- Add `_stagger_extra_barrier_if_one` / `_stagger_extra_barrier_if_zero`
  helpers that emit inline assembly:
      s_cmp_eq_u32 $0, 0
      s_cbranch_scc{1,0} 1f
      s_barrier
      1:
  Verified in the final ISA: `21_final_isa.s` shows the expected
  s_cmp + s_cbranch_scc + s_barrier triple at both stagger sites and
  the asymmetric barrier counts line up across waves.
- Two `if const_expr(OPUS_ENABLE_STAGGER)` gates:
    * Prologue stagger site (post-vmcnt, pre-mma0): emits the asymmetric
      barrier when ON, an unconditional `sched_barrier(0) + gpu.barrier()`
      when OFF.
    * Pre-store stagger site (post-`inv_l`, pre-global store): emits the
      complementary asymmetric barrier when ON, an unconditional
      `gpu.barrier()` when OFF.
- Add `scf` import for completeness (uses confined to helpers).
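
The asymmetric-barrier bookkeeping above can be sketched in plain Python. The model below is illustrative only: `WARP_SIZE = 64` and the number of main-loop barriers are assumptions, and `barrier_trace` is a hypothetical helper that lists barrier sites rather than emitting any ISA.

```python
WARP_SIZE = 64   # assumed wavefront size
NUM_WAVES = 8    # waves 0-3 and 4-7 form the two groups


def stagger_of(thread_id):
    """Mirror of `stagger = (tid / WARP_SIZE) / 4` from the C++ snippet."""
    wave_id = thread_id // WARP_SIZE
    return wave_id // 4


def barrier_trace(stagger, main_loop_barriers):
    """Ordered list of barrier sites one wave passes through."""
    trace = []
    if stagger:                      # prologue site: extra barrier for waves 4-7
        trace.append("prologue_extra")
    trace.extend(["loop"] * main_loop_barriers)
    if not stagger:                  # pre-store site: complementary barrier for waves 0-3
        trace.append("pre_store_extra")
    return trace
```

Both groups execute the same total number of barriers, but the first main-loop barrier is ordinal 0 for waves 0-3 and ordinal 1 for waves 4-7, which is the one-barrier phase shift the commit describes.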

Why default OFF
- The asymmetric barrier is verified correct at the ISA level. However,
  enabling it currently produces wrong results because the FlyDSL kernel
  loads V from LDS inside Cluster 3 (via
  `_read_v_packs_for_k_substep(0, ...)`), while the C++ reference loads
  V into registers in Cluster 2 (tr_load before the cluster-2 barrier).
  With phase-shifted execution, warps 4-7 end up reading `s_v[0]` after
  warps 0-3 have already issued the Cluster-4 async_load that
  overwrites it. Hoisting the V reads is a structural change outside
  the P5 scope and is left to a follow-up phase.
- With stagger OFF, all 8 waves stay in lockstep so the LDS lifetime
  invariants hold and the kernel still produces correct results.

Verified (FLYDSL_OPUS_STAGGER unset → OFF)
- seq_len=512   causal MaxErr 3.91e-03, nocausal MaxErr 4.88e-04 (PASS)
- seq_len=8192  causal MaxErr 3.91e-03, nocausal MaxErr 2.44e-04 (PASS)

Co-authored-by: Cursor <cursoragent@cursor.com>
The P5 stagger path (FLYDSL_OPUS_STAGGER=1) previously produced wrong
results because V was being read from LDS inside Cluster 3/7/11/13 (i.e.
AFTER the cluster-2/6/10/12 s_barrier), while the C++ template loads V
into registers in the preceding cluster (tr_load BEFORE the s_barrier).
Under the dual-group phase shift, warps 4-7 would end up reading from
s_v[*] AFTER warps 0-3 had already issued the next async_load that
overwrites the same LDS buffer — a data race.

Fix: move all 6 V LDS read sites one cluster earlier so V is captured
into VGPRs BEFORE each cluster-boundary barrier, mirroring the C++
template gqa_d128_kernel_template.hpp exactly:

  Main loop:
    - Cluster 3 V[j-3] from s_v[0]   →  hoisted into Cluster 2
    - Cluster 7 V[j-2] from s_v[1]   →  hoisted into Cluster 6
  Epilogue:
    - Cluster 3 V[max-4] from s_v[0] →  hoisted into Cluster 2 (epi)
    - Cluster 7 V[max-3] from s_v[1] →  hoisted into Cluster 6 (epi)
    - Cluster 11 V[max-2] from s_v[0]→  hoisted into Cluster 10 (epi)
    - Cluster 13 V[max-1] from s_v[1]→  hoisted into Cluster 12 (epi)

With V in VGPRs across each cluster boundary, peer async_loads
overwriting the LDS buffer are harmless.
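
A toy model of the race and of the fix, under stated assumptions: each inner list is the work between two consecutive barriers, group 0 (warps 0-3) runs `offset` barrier intervals ahead of group 1 (warps 4-7), and within one interval the leading group is assumed to run first (the worst case). The op names are hypothetical labels, not FlyDSL calls.

```python
def simulate(schedule, offset):
    """Return the V value each group ends up consuming."""
    lds = "V_old"        # contents of one s_v[*] buffer
    regs = {}            # per-group VGPR copy of V
    consumed = {}
    n = len(schedule)
    for step in range(n + offset):
        # Leading group runs segment `step`; lagging group runs `step - offset`.
        for group, seg in ((0, step), (1, step - offset)):
            if not 0 <= seg < n:
                continue
            for op in schedule[seg]:
                if op == "read_lds":
                    regs[group] = lds
                elif op == "use_regs":
                    consumed[group] = regs[group]
                elif op == "async_overwrite":
                    lds = "V_new"
    return consumed


# V read after the barrier, in the same cluster that consumes it.
LATE_READ = [[], ["read_lds", "use_regs"], ["async_overwrite"]]
# V read one cluster earlier, before the barrier.
HOISTED = [["read_lds"], ["use_regs"], ["async_overwrite"]]
```

With `offset=0` (stagger OFF) both schedules are safe; with `offset=1` only the hoisted schedule keeps the lagging group on the old tile.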

Other changes:
- Update banner comment near stagger setup: now states the path is
  correctness-safe and no longer warns that V hoisting is required.
- Update inline comments at both stagger sites and OPUS_ENABLE_STAGGER
  env-var docs to reflect the new state.
- Keep FLYDSL_OPUS_STAGGER default OFF: the asymmetric barrier
  currently regresses throughput in this port (108 → 85 TFLOPS @ S=8192
  B=16 causal) due to extra V-substep register pressure across the
  barrier. The user can opt in with FLYDSL_OPUS_STAGGER=1 once the
  scheduling tradeoff is addressed.

Verified (S=8192 B=16 H=64 D=128 bf16, MI355X, FLYDSL_OPUS_STAGGER=1):
  causal:    MaxErr 3.91e-03 (matches OPUS C++ bit-for-bit), 85.0 TFLOPS
  nocausal:  MaxErr 2.44e-04 (matches OPUS C++ bit-for-bit), 78.1 TFLOPS

Verified (S=8192 B=16 H=64 D=128 bf16, MI355X, FLYDSL_OPUS_STAGGER=0):
  causal:    MaxErr 3.91e-03, 108.5 TFLOPS (no regression vs prior tip)
  nocausal:  MaxErr 2.44e-04, 103.6 TFLOPS

Also verified at small scale (S=256 and S=512): both stagger settings PASS.

Co-authored-by: Cursor <cursoragent@cursor.com>
The P5 stagger path was correctness-fixed in the previous commit (V LDS
reads hoisted into Cluster 2/6/10/12). To make the OPUS path fully
self-contained — i.e. setting only `FLYDSL_ENABLE_OPUS_PATH=1` actually
exercises every P1..P6 modification — flip the default for
`FLYDSL_OPUS_STAGGER` to ON.

Now the OPUS path requires no auxiliary env vars: all of LAZY_RESCALE,
SETPRIO, STAGGER, YIELD_NOP default ON, so the kernel is a faithful
end-to-end port of gqa_d128_kernel_template.hpp.

Verified with the exact user-requested command on MI355X:
  FLYDSL_ENABLE_OPUS_PATH=1 \
  python tests/kernels/test_flash_opus_attn.py --causal \
      --dtype bf16 --batch 16 --num_heads 64 --num_kv_heads 64 \
      --seq_len 8192 --head_dim 128 --iters 100 --compare

  causal: MaxErr 3.91e-03 (bit-for-bit match with OPUS C++), 85.0 TFLOPS

`FLYDSL_OPUS_STAGGER=0` remains available as an escape hatch for A/B testing.

Co-authored-by: Cursor <cursoragent@cursor.com>
Atomic switch of K/V LDS layout, DMA writers, register-side readers, and
causal mask to match OPUS gqa_d128_kernel_template.hpp (sections 4-5).

Key changes (all in lockstep, no env toggle):
- LDS layout: interleaved K0/V0/K1/V1 double-buffer, 68096 B total
  (K tile 8320 bf16 x 2, V tile 8704 bf16 x 2). Line stride is
  smem_linear_wave + smem_padding (K: 520 bf16, V: 544 bf16) so the
  hardware ds_read / ds_read_tr16_b64 path is bank-conflict-free.
- coop_dma_k / coop_dma_v rewritten on the OPUS u_gk/u_sk/u_gv/u_sv
  layouts: each wave owns 8 rows of N x 64 bf16 of D, two d_rpt stripes
  per buffer, raw_ptr_buffer_load_lds into the new line-strided slots.
- _read_k_packs_for_buf rewritten on OPUS u_rk: lane%32 = ((m%8)*8+m/8)
  with d_rpt-major step_k layout (outer stride 4160 bf16, inner 16,
  v_s_hi at +256 bf16 from v_s_lo).
- _read_v_packs_for_k_substep rewritten on OPUS u_rv via
  ds_read_tr16_b64 (v4f16) + 8-lane shuffle: 4 D-chunks per step_k with
  per-lane base = grp_k*(lane/32) + lane_hi*((lane%16)/4) +
  grp_n*((lane/16)%2) + lane_lo*(lane%4).
- _causal_mask_inplace: N-axis thresholds reordered to the OPUS
  permutation pi(m) = (m%8)*8 + m/8 so the 8 register-anchored vec2
  comparisons cover the permuted N positions of S; v_s_hi delta becomes
  -4 (matches OPUS smem_d_n_split=4 N-half offset).
- SmemAllocator finalized against LDS_KV_TOTAL_SIZE (68096 B).
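
The layout numbers quoted above are mutually consistent, which a few lines of arithmetic confirm. The sketch assumes the four tiles are packed contiguously in the stated K0/V0/K1/V1 order; the byte offsets it derives are therefore an inference, not values taken from the kernel source. `opus_perm` is a hypothetical name for the permutation pi(m).

```python
BF16_BYTES = 2
K_LINE_STRIDE = 520      # bf16 elements per K line (from the commit)
V_LINE_STRIDE = 544      # bf16 elements per V line (from the commit)
K_TILE_ELEMS = 8320
V_TILE_ELEMS = 8704
LDS_KV_TOTAL_SIZE = 68096


def lds_offsets():
    """Byte offsets of K0, V0, K1, V1 assuming contiguous packing in that order."""
    offsets = {}
    cursor = 0
    for name, elems in (("K0", K_TILE_ELEMS), ("V0", V_TILE_ELEMS),
                        ("K1", K_TILE_ELEMS), ("V1", V_TILE_ELEMS)):
        offsets[name] = cursor
        cursor += elems * BF16_BYTES
    return offsets, cursor


def opus_perm(m):
    """pi(m) = (m % 8) * 8 + m / 8, i.e. an 8x8 index transpose."""
    return (m % 8) * 8 + m // 8
```

Each tile is exactly 16 lines at its line stride, the K outer stride of 4160 bf16 is half a K tile, and pi is its own inverse, so applying the same reordering twice restores the natural N order.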

Verified on hyg_trn_rocm7.1 / MI355X with cleared ~/.flydsl/cache:
  FLYDSL_ENABLE_OPUS_PATH=1 python tests/kernels/test_flash_opus_attn.py \
    --warmup 5 --iters 100

All 15 configs PASSED (max err 3.91e-03 < 1e-2, min cos 0.99999 > 0.99).

Co-authored-by: Cursor <cursoragent@cursor.com>
- Add ds_read_b64_tr_b16 immediate-offset inline asm; V LDS reads in OPUS issue order
- Fix K/V coop_dma global row index (n_in_warp*NUM_WAVES+wave_id); clarify u_gk comments
- Add raw buffer resource helper for O; set DMA aux 0 per loader path
- run.sh: run causal opus test without --compare; keep compare variant commented
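
The corrected row index `n_in_warp*NUM_WAVES+wave_id` interleaves rows across waves. A minimal sketch, assuming 8 waves with 8 rows each (taken from the earlier "each wave owns 8 rows" description); `rows_owned_by` is a hypothetical helper for illustration.

```python
NUM_WAVES = 8        # assumed wave count per workgroup
ROWS_PER_WAVE = 8    # assumed rows each wave copies per tile


def global_row(n_in_warp, wave_id):
    """Row of the K/V tile that a wave loads on its n-th step."""
    return n_in_warp * NUM_WAVES + wave_id


def rows_owned_by(wave_id):
    """All tile rows handled by one wave."""
    return [global_row(n, wave_id) for n in range(ROWS_PER_WAVE)]
```

Under these assumptions every row of a 64-row tile is covered exactly once, and consecutive rows belong to consecutive waves.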

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- FLASH_ATTN_OPUS_Kernel_Analysis_Detail.md: FlyDSL OPUS kernel walkthrough
- FLASH_ATTN_OPUS_vs_CPP_Differences.md: Python vs C++ OPUS mapping
- GQA_D128_KERNEL_Analysis_Detail.md: GQA d128 reference-kernel analysis

Verified: documentation only
Co-authored-by: Cursor <cursoragent@cursor.com>
- Split K/V buffer_load_lds addressing into uniform `soffset` and per-lane `voffset`
- Drop redundant `_dma_soff`; softmax exp paths use `rocdl.exp2` instead of arith.exp2

Verified: not run in this session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Add q_rsrc and use buffer_ops.buffer_load for per-step Q MFMA packs
- Remove redundant global half-vec helpers and unused Q/K/V/O pointer locals

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… readers

- Move Q/K/V/O resources, DMA knobs, causal tile bounds, and scaled Q packs
  (buffer_load) ahead of MFMA/sched-group helpers for clearer codegen order
- Pass urk_base_per_lane / urv_base_per_lane into K/V ds_read helpers
- Trim repetitive cluster/epilogue commentary (no behavior change intended)

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Annotate prologue, steady-state clusters, softmax/rescale path, epilogue/store
- Map gqa_d128_kernel_template constructs (layouts, async_load, mma*, barriers)

Verified: comment-only annotations (no codegen logic change intended)
Co-authored-by: Cursor <cursoragent@cursor.com>
- Rename coop_dma/read/wave-max helpers to async_load_* and attn_* for reader parity
- Reorder prologue and inner-cluster sequences (GEMM0, mask, waits, DMA) to follow OPUS pipeline comments

Verified: not re-run in this session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Extend launcher/kernel with stride_q_n, stride_kv_n, head_dim_runtime (defaults unchanged)
- Build Q/K/V buffer resources using base-byte offsets; DMA uses stride_kv_n for GMEM indexing
- Compute softmax scale as rsqrt(head_dim_runtime)*log2e at kernel entry
- Use uniform wave id for K/V DMA LDS line addressing; simplify LDS ptr creation for DMA
- Move Q pack load/load path after barrier; buffer_load uses row-in-block indexing via stride_q_n
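
Folding `log2e` into the scale lets the softmax use base-2 exponentials, since exp(x) = 2^(x * log2 e). The sketch below checks that identity in plain Python; the function names are hypothetical and it says nothing about how the kernel orders the multiplications.

```python
import math

LOG2E = math.log2(math.e)


def softmax_scale_log2(head_dim):
    """rsqrt(head_dim) * log2e, as computed at kernel entry."""
    return (1.0 / math.sqrt(head_dim)) * LOG2E


def softmax_exp2(scores, head_dim):
    """Softmax of scaled scores using only base-2 exponentials."""
    c = softmax_scale_log2(head_dim)
    m = max(scores)
    p = [2.0 ** ((s - m) * c) for s in scores]
    total = sum(p)
    return [x / total for x in p]


def softmax_ref(scores, head_dim):
    """Conventional softmax(scores / sqrt(head_dim)) with natural exp."""
    scale = 1.0 / math.sqrt(head_dim)
    m = max(scores)
    p = [math.exp((s - m) * scale) for s in scores]
    total = sum(p)
    return [x / total for x in p]
```

The two functions agree to floating-point rounding for any `head_dim`, so the exp2 form changes only which hardware instruction is used, not the math.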

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Split buffer_load → f32 extend → softmax-scale → bf16 trunc into helpers
- Keeps per-step logic identical; improves readability near async K prefetch

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Load/concat Q to one bf16 shard; scale/trunc via broadcast once; MFMA slices via shuffle
- Replace stagger inline-asm with scf.If + sched_barrier + rocdl.s_barrier for LLVM intrinsic paths
- Add convert-ub-to-llvm to RocmBackend pipeline

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…cale

- Load Q packs via RawPtrBufferLoadOp (i32 vec) with byte offsets; bitcast to bf16
- Replace Vec extf/mul/truncf in scale_q_all with llvm FPExtOp/FPTruncOp + arith.mulf and fastmath attrs

Verified: not re-run in this session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Add shared alias_scope domain plus lds_{k,v}{0,1} tags on raw_ptr_buffer_load_lds
- Load K mfma packs via aligned llvm.LoadOp with matching alias/noalias scopes
- Pass **kw through buffer_load_to_lds wrappers to raw_ptr_buffer_load_lds

Verified: not benchmarked in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Implement attn_mask as one inline-asm block (2× v_cmp + 2× v_cndmask) for lo/hi pairs
- Split s_hi pair masking into a second constexpr loop; mark asm side effects
- Add scf.if prologue when q_start_pos < KV_TILE_SIZE (C++ attn_mask path)

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… order

- Replace xor shuffle reductions with rocdl.permlane32_swap pairs for row-max/sum
- Anchor paired vec operands via llvm.inline_asm struct outputs (=v,=v,0,1)
- Match OPUS prologue: attn_sub_row before anchor; split exp2-first-half slice path

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- Replace wave-sync sites with explicit s_barrier to mirror OPUS C++ path
- Keep sched_barrier pairing unchanged; no intended semantic change

Verified: not re-run in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
yanguahe and others added 22 commits May 29, 2026 09:42
…ests

- Move flash_attn_opus.v1.s under exp_isa/ with build.sh, opus_asm_ext, and Python wrapper
- Extend trace_segment_cycles.py and specific_part.json segment fixtures
- test_flash_opus_attn: chunked PyTorch ref for large score tensors; exp_isa path hooks

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…counts

- Add _anchor_v_o inline-asm pins for four 16xf32 accumulators (C++ v_o_pin pattern)
- Call after non-lazy corr scale and epilogue rescale_e3/e7/e11 paths
- Bump _sched_barrier_pairs (4,6,*) / (6,6,*) / (9,6,10) and _sched_barrier_exp_pairs (7,3,10)
- Refresh exp_isa/flash_attn_opus.v1.s annotations to match

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…re harness

- trace_segment_cycles: optional specific_part trace pairs, perf-counter series,
  instruction-type count deltas, and related compare reporting
- ana_trace.sh: factor interval append + print_selected_summary; fyd_cpp_compare only
- seg_asm: update fyd_cpp_compare.json and specific_part.json fixtures
- test_flash_opus_attn: compare path uses aiter asm backend (exp_isa bench commented)
- run.sh: build exp_isa .co before compare; exp_isa/build.sh skips py ext by default
- Minor refresh of opus GQA d128 causal gfx950 ISA reference

Co-authored-by: Cursor <cursoragent@cursor.com>
- point15: drop waitcnt/barrier; anchor on setprio + s_mov + GEMM2 MFMA
- point19: anchor on setprio + MFMA + v_max3 (match current FlyDSL ISA layout)

Co-authored-by: Cursor <cursoragent@cursor.com>
- input_hand_asm_thread_trace.yaml targets hand-asm kernel with waves_per_eu=2 trace dir
- Enable ATT perf counters (VALU/MFMA busy and coexec) for segment analysis

Co-authored-by: Cursor <cursoragent@cursor.com>
…hors

- _attn_sub_row returns packed f32 vectors directly; remove _anchor_v_s calls
- Add _anchor_scalar_f32 in lazy rescale cold path so m_tile_max merges as PHI
- Re-enable llvm.intr_expect on all_below branch predicate

Co-authored-by: Cursor <cursoragent@cursor.com>
…re hooks

- Vendor fmha_fwd 256x64 gfx950 msk0/msk1 .s; fmha_asm.py + fmha_asm_ext.cc
- exp_isa/build.sh: compile opus v1 + both FMHA .co objects; setup.py extension wiring
- Add flash_attn_opus.v0.s snapshot; minor v1.s touch
- input_asm_fmha_thread_trace.yaml; ana_trace dumps asm_fmha b2_s1024 traces
- test_flash_opus_attn: run_exp_isa_fmha_bench; compare uses exp_isa FMHA asm path
- run.sh: duplicate compare invocation after exp_isa build

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…ead helpers

- Replace mirrored C++ line comments with prologue/main-loop/epilogue/stagger docs
- Drop unused KV load batching constants and _waitcnt_lgkm_0_vm_n wrapper
- Use _waitcnt_vm_n / s_waitcnt(lgkmcnt(0)) directly at cluster boundaries
- Minor epilogue sched_barrier pair count tweaks (e.g. cluster 11)

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
- flash_attn_func.py -> flash_attn_generic.py (flash_attn_generic_kernel)
- flash_attn_opus.py -> flash_attn_gfx950.py (build_flash_attn_dualwave_swp_module,
  flash_attn_dualwave_swp_gfx950_kernel / launch_flash_attn_dualwave_swp)
- test_flash_opus_attn.py -> test_flash_attn_fwd.py; test_flash_attn_func.py -> test_flash_attn_fwd_ori.py
- Update imports, dispatcher wiring, and test references to new module names

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… tree

- Move ana_trace, rocprof YAML inputs, seg_asm fixtures, trace_segment_cycles,
  perf CSVs, run/compare helpers, and analysis markdown into fmha_opt_tools/
- Remove FLASH_ATTN_OPUS_vs_CPP_Differences.md (superseded by in-tree analysis)
- Update fmha_opt_tools/run.sh for flash_attn_fwd / dualwave env toggles

Co-authored-by: Cursor <cursoragent@cursor.com>
- Dispatch dualwave SWP from flash_attn_generic when dtype is bf16 or f16
- Select mfma_f32_32x32x16_bf16 vs _f16; fp16 pack/trunc for softmax P and O store
- Relax builder dtype check and update module docstrings
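
A sketch of the dispatch predicate, using the conditions listed in the PR description. The function names are hypothetical, and the architecture check is simplified to an exact string match even though the description states `gpu_arch >= gfx950`.

```python
SUPPORTED_DTYPES = {"bf16", "f16"}


def use_dualwave_swp(gpu_arch, head_dim, dtype, seq_len):
    """True when the gfx950 fast path applies; otherwise use the generic kernel."""
    return (
        gpu_arch == "gfx950"          # simplification of `gpu_arch >= gfx950`
        and head_dim == 128
        and dtype in SUPPORTED_DTYPES
        and seq_len >= 384
        and seq_len % 256 == 0
    )


def mfma_name(dtype):
    """MFMA variant selected per dtype, as named in the commit."""
    return {
        "bf16": "mfma_f32_32x32x16_bf16",
        "f16": "mfma_f32_32x32x16_f16",
    }[dtype]
```

Note that the two sequence-length conditions together make 512 the smallest accepted length: 384 satisfies the lower bound but is not a multiple of 256.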

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
…JSON

- Update fyd_cpp_compare.json instruction anchors and fyd_attn specific_part
- ana_trace.sh: fix paths under fmha_opt_tools/; default log at.log
- Remove superseded main_loop_cluster0_7*.json (merged into fyd_cpp_compare)

Co-authored-by: Cursor <cursoragent@cursor.com>
…rt.json

- fyd_cpp_compare point11: match attn_sub_row v_sub before setprio 0 / s_barrier
- Drop seg_asm/specific_part.json (superseded by fyd_cpp_compare.json)

Co-authored-by: Cursor <cursoragent@cursor.com>
- Wire FLYDSL_DUALWAVE_SWP_* env vars into build and optional debug_counts launch
- Add TRIGGER_LAZY_ELSE Q/K fixture and lazy branch counter reporting
- Use chunked PyTorch reference when SDPA score workspace exceeds 128M elems
- Warm up benchmark path under torch.profiler before run_perftest
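
Chunking the reference is safe because each query row's softmax is independent of the others. The sketch below shows the idea in plain Python rather than PyTorch; `chunk_rows` is a hypothetical parameter, the 128M-element threshold is not modeled, and equal Q and K sequence lengths are assumed for the causal mask.

```python
import math


def attention_rows(q_rows, k, v, scale, causal, row_offset):
    """Reference attention for a slice of query rows starting at row_offset."""
    out = []
    for i, q in enumerate(q_rows):
        q_pos = row_offset + i
        scores = []
        for j, k_row in enumerate(k):
            if causal and j > q_pos:        # mask when q_pos < K column
                scores.append(-math.inf)
            else:
                scores.append(scale * sum(a * b for a, b in zip(q, k_row)))
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        denom = sum(p)
        out.append([sum(pj * v_row[d] for pj, v_row in zip(p, v)) / denom
                    for d in range(len(v[0]))])
    return out


def chunked_attention(q, k, v, scale, causal, chunk_rows):
    """Process query rows in chunks so only chunk_rows x len(k) scores exist at once."""
    out = []
    for start in range(0, len(q), chunk_rows):
        chunk = q[start:start + chunk_rows]
        out.extend(attention_rows(chunk, k, v, scale, causal, start))
    return out
```

Passing the chunk's starting row as `row_offset` is what keeps the causal mask aligned with absolute positions; without it every chunk would be masked as if it began at row 0.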

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
… py ext

- Add build_env_and_run_benchmark.sh (aiter, LLVM, FlyDSL, opus_attn, exp_isa, then compare bench)
- exp_isa/build.sh: re-enable setup.py build_ext --inplace for asm Python extensions

Verified: not rerun in this commit session
Co-authored-by: Cursor <cursoragent@cursor.com>
 - Absorb rocm/main infrastructure: CI, custom LLVM tools, external_llvm,
   backend cmake, and related FlyDSL framework updates
 - Resolve conflicts keeping opus_align functionality (GQA, dualwave SWP,
   flash_attn_generic module naming, 5-tuple test configs)
 - From rocm/main in test_flash_attn_fwd_ori: _custom_llvm_tools_env() wrapper
   and multi-line formatting only

Co-authored-by: Cursor <cursoragent@cursor.com>
…ilds

 - Port _custom_llvm_tools_env() from test_flash_attn_fwd_ori to gfx950 test
 - Wrap build_flash_attn_dualwave_swp_module in custom LLVM context manager
 - fmha_opt_tools/run.sh: enable FLYDSL_FLASH_ATTN_FUNC_USE_CUSTOM_LLVM=1
 - Refresh run.sh benchmark command examples (fp16/GQA/compare variants)

Co-authored-by: Cursor <cursoragent@cursor.com>
 - Rename launch_flash_attn_func -> launch_flash_attn_generic for module clarity
 - rocm backend: remove convert-ub-to-llvm from default LLVM lowering pipeline
 - fmha_opt_tools/run.sh: IR dump + cache clear dev workflow; 8192 compare bench

Co-authored-by: Cursor <cursoragent@cursor.com>
@yanguahe yanguahe force-pushed the opus_align_pr branch 2 times, most recently from 175b5e7 to edb4a5d Compare June 3, 2026 09:22
yanguahe and others added 2 commits June 3, 2026 08:19
 - Point docs and README to flash_attn_generic/gfx950 and test_flash_attn_fwd
 - run_benchmark.sh: 7-field GQA shapes, --num_kv_heads, legacy 6-field fallback

Co-authored-by: Cursor <cursoragent@cursor.com>
 - Remove duplicated _custom_llvm_tools_env helpers from gfx950 test harness
 - Build dualwave module with bundled FlyDSL LLVM (custom LLVM stays in ori test)

Co-authored-by: Cursor <cursoragent@cursor.com>
@coderfeli coderfeli merged commit 293b3b2 into main Jun 4, 2026
12 checks passed
@coderfeli coderfeli deleted the opus_align_pr branch June 4, 2026 11:39