pa: vendor PA metadata scheduler, drop aiter get_pa_metadata_v1 dep #606
fsx950223 wants to merge 5 commits into
Port the FlyDSL-native PA worklist/reduce-map scheduler (kernels/pa_metadata.py) and its matching persistent FP8 decode kernel from the pa_update_metadata line onto latest main. This replaces the runtime dependency on aiter's get_pa_metadata_v1 / get_pa_metadata_info_v1 (cloned unpinned from aiter HEAD in CI), which drifted and produced uninitialized garbage output in the PS decode path. pa_reduce_v1 is still sourced from aiter.

Files:
- kernels/pa_metadata.py: new, FlyDSL impl of get_pa_metadata_v1/_info_v1
- kernels/pa_decode_fp8.py: native 256-token partition layout (partition_indptr, num_cu oversubscription); no _expand_pa_metadata_for_block_splits
- kernels/pa_decode_swa.py: matching sliding-window decode/reduce
- tests/kernels/test_pa.py: align harness with new launch signatures

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Pull request overview
Vendors aiter's get_pa_metadata_v1 / get_pa_metadata_info_v1 as a native FlyDSL kernel and re-ports the matching persistent FP8 decode kernel onto current main. The new metadata kernel removes a runtime dependency on an unpinned aiter HEAD that had been producing uninitialized output, and the decode kernel is rewritten around a 256-token partition layout with CU oversubscription. Parameterization on head_dim is plumbed through both PA decode kernels and the SW helpers, replacing the hard-coded HEAD_SIZE=128 / QUERY_GROUP_SIZE=16 constants.
Changes:
- New kernels/pa_metadata.py implementing the PA worklist/reduce-map scheduler as a single-warp FlyDSL kernel, drop-in for aiter.ops.attention.get_pa_metadata_v1.
- Rewrite of kernels/pa_decode_fp8.py: renamed compile_pa_decode_ps → compile_pa_decode_metadata, new partition-granularity worklist, oversubscribed persistent grid, small-block (16/64) path with optional metadata routing, _expand_pa_metadata_for_block_splits removed.
- pa_decode_swa.py and the FP8 helpers are parameterized on head_dim / qkhe_loop / vhe_loop / q_lanes_per_head; the test harness is updated to the new get_recommended_splits and broader BLOCK_SIZE_OPTIONS.
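To make the partition-granularity worklist concrete, here is a minimal host-side sketch in plain Python. It only assumes the 256-token partition size and the partition_indptr name from the PR; the function names, the contiguous chunking policy, and the worker count are hypothetical and are not taken from kernels/pa_metadata.py, which runs as a single-warp FlyDSL kernel on the GPU and also produces a reduce map.

```python
from typing import List, Tuple

# Partition size named in the PR. Everything else in this sketch is assumed.
PARTITION_TOKENS = 256


def build_partition_indptr(context_lens: List[int]) -> List[int]:
    """CSR-style prefix sum of per-sequence partition counts."""
    indptr = [0]
    for num_tokens in context_lens:
        # Ceil-divide the context length into 256-token partitions.
        num_parts = (num_tokens + PARTITION_TOKENS - 1) // PARTITION_TOKENS
        indptr.append(indptr[-1] + num_parts)
    return indptr


def build_worklist(
    context_lens: List[int], num_workers: int
) -> List[List[Tuple[int, int]]]:
    """Split all (sequence, partition) items into near-equal contiguous chunks.

    Hypothetical policy: the real scheduler may balance work differently.
    """
    indptr = build_partition_indptr(context_lens)
    items = [
        (seq, part)
        for seq in range(len(context_lens))
        for part in range(indptr[seq + 1] - indptr[seq])
    ]
    base, extra = divmod(len(items), num_workers)
    worklist = []
    start = 0
    for worker in range(num_workers):
        # The first `extra` workers take one additional item.
        count = base + (1 if worker < extra else 0)
        worklist.append(items[start:start + count])
        start += count
    return worklist
```

With an oversubscribed persistent grid, num_workers would be a multiple of the CU count rather than the CU count itself; the sketch leaves that factor to the caller.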
Reviewed changes
Copilot reviewed 3 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| kernels/pa_metadata.py | New FlyDSL worklist scheduler replacing the aiter dependency. |
| kernels/pa_decode_fp8.py | Renamed/restructured PS decode kernel, new partition layout, small-block path, head_dim parameterization, switches metadata source to local module. |
| kernels/pa_decode_swa.py | Removes HEAD_SIZE/QUERY_GROUP_SIZE constants and threads head_dim / qkhe_loop / vhe_loop through helpers. |
| tests/kernels/test_pa.py | Updates harness to get_recommended_splits, drops obsolete block_size/head_size guards, expands sliding-window block_size matrix. |
    @flyc.kernel
    def pa_decode_ps_kernel(
    @flyc.kernel(known_block_size=(BLOCK_THREADS, 1, 1))
    def pa_decode_metadata_kenrel(

    # value_attrs=_mfma_agpr_value_attrs(),
    ).launch(grid=(num_sm, 1, 1), block=(BLOCK_THREADS, 1, 1), stream=stream)

    launch_pa_decode_ps.compile_hints["llvm_options"] = PA_MFMA_AGPR_LLVM_OPTIONS
    # launch_pa_decode_metadata.compile_hints["llvm_options"] = PA_MFMA_AGPR_LLVM_OPTIONS

    gpu = torch.cuda.current_device()
    num_cu = torch.cuda.get_device_properties(gpu).multi_processor_count
    cu_num = num_cu
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Fixes the Python style CI check (ruff isort) on PR #606. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…n_kv for small blocks

The grid small-block kernel compile_pa_decode_ps existed only to carry the p_scale, v_scale_per_head, and k_scale_per_token quant variants. Remove all three flags (and the grid kernel + its host wiring), and instead implement per_token_kv=True for small block_size 16/64 in the metadata path so every KV-quant mode routes through compile_pa_decode_metadata.

- Delete compile_pa_decode_ps and its dedicated helpers; small blocks now always use the load-balanced worklist (metadata) kernel.
- Strip p_scale / k_scale_per_token from _make_pa_phase_helpers (_k_scale_per_token_eff collapses to per_token_kv).
- compile_pa_decode_metadata: remove the small-block per_token_kv NotImplementedError; add _meta_stage_small_block_kv_scales (per-token K/V scale gather across the partition's gathered pages, matching the existing scale_lds_f32 layout) + _meta_load_small_block_scale_vecs. Small blocks stage scales fresh per iteration (no loop-carried prefetch).
- pa_decode_ps_launch: drop p_scale/p_scale_inv/v_scale_per_head/k_scale_per_token params; per_token_kv = key_scale.ndim > 1; remove the grid dispatch branch and the _PA_DECODE_PS_SMALL_BLOCK_VIA_METADATA flag.

Verified on MI308X (gfx942): tests/kernels/test_pa.py passes, including the block_size=16 + per_token case that previously hit NotImplementedError.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review, keep the grid small-block kernel compile_pa_decode_ps and split routing by block size: 16/64 -> compile_pa_decode_ps, >=256 (1024) -> compile_pa_decode_metadata. The three removed scale flags (p_scale/v_scale_per_head/k_scale_per_token) stay gone.

- Restore compile_pa_decode_ps, stripped of p_scale/v_scale_per_head and the old packed k_scale_per_token machinery. Implement per_token_kv=True in it: per-token K/V scales use the SAME layout as the metadata kernel ([num_blocks, num_kv_heads, block_size], indexed phys*stride_ks_block + kv_h*stride_ks_head + token_in_page), staged fresh per partition from the gathered phys-page LDS via _stage_small_block_kv_scales + _load_small_block_scale_vecs, feeding the shared per_token_kv path (_qk_and_intra_softmax preloaded_scales, _store_vmax_warp, cross-warp v_max).
- pa_decode_ps_launch: dispatch 16/64 -> compile_pa_decode_ps + sw_reduce; pass per_token_kv K/V scale strides (metadata layout); 1024 stays metadata.
- compile_pa_decode_metadata: revert the small-block per_token_kv support added in the previous commit (restore NotImplementedError); small blocks no longer route there.

Verified on MI308X (gfx942): tests/kernels/test_pa.py passes. block_size=16 + per_token now runs through compile_pa_decode_ps; block_size=1024 through compile_pa_decode_metadata.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
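The routing rule and the per-token scale indexing described in this commit can be sketched in plain Python as follows. The offset formula and the kernel names come from the commit message; the helper names, the contiguous-stride assumption, and the error raised for other block sizes are hypothetical and only illustrate the arithmetic, not the actual host code in pa_decode_ps_launch.

```python
from typing import Tuple


def contiguous_scale_strides(num_kv_heads: int, block_size: int) -> Tuple[int, int]:
    """Element strides for a contiguous [num_blocks, num_kv_heads, block_size] tensor.

    Assumes a dense row-major scale tensor; the real tensor may be strided differently.
    Returns (stride_ks_block, stride_ks_head).
    """
    return num_kv_heads * block_size, block_size


def per_token_scale_offset(
    phys: int,
    kv_h: int,
    token_in_page: int,
    stride_ks_block: int,
    stride_ks_head: int,
) -> int:
    """Flat offset of one per-token K/V scale, as given in the commit message."""
    return phys * stride_ks_block + kv_h * stride_ks_head + token_in_page


def select_decode_kernel(block_size: int) -> str:
    """Route by block size: 16/64 to the grid kernel, >=256 to the metadata kernel."""
    if block_size in (16, 64):
        return "compile_pa_decode_ps"
    if block_size >= 256:
        return "compile_pa_decode_metadata"
    # Hypothetical handling: the commit message does not cover other sizes.
    raise ValueError(f"unsupported block_size: {block_size}")
```

For example, with 8 KV heads and block_size=16 the strides are (128, 16), so the scale for token 5 of head 2 in physical page 3 sits at offset 3*128 + 2*16 + 5.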