pa: vendor PA metadata scheduler, drop aiter get_pa_metadata_v1 dep #606

Open: fsx950223 wants to merge 5 commits into main from pa-vendored-metadata

@fsx950223
Contributor

Port the FlyDSL-native PA worklist/reduce-map scheduler (kernels/pa_metadata.py) and its matching persistent FP8 decode kernel from the pa_update_metadata line onto latest main. This replaces the runtime dependency on aiter's get_pa_metadata_v1 / get_pa_metadata_info_v1 (cloned unpinned from aiter HEAD in CI), which drifted and produced uninitialized garbage output in the PS decode path. pa_reduce_v1 is still sourced from aiter.

Files:

  • kernels/pa_metadata.py: new, FlyDSL impl of get_pa_metadata_v1/_info_v1
  • kernels/pa_decode_fp8.py: native 256-token partition layout (partition_indptr, num_cu oversubscription); no _expand_pa_metadata_for_block_splits
  • kernels/pa_decode_swa.py: matching sliding-window decode/reduce
  • tests/kernels/test_pa.py: align harness with new launch signatures
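
As a rough illustration of the 256-token partition layout named above, the sketch below builds a CSR-style prefix sum over per-sequence partition counts. The function name `build_partition_indptr` and the interpretation of `partition_indptr` as such a prefix sum are assumptions made for illustration; the actual layout in kernels/pa_decode_fp8.py may differ.

```python
from itertools import accumulate

PARTITION_SIZE = 256  # tokens per partition, taken from the PR description


def build_partition_indptr(context_lens, partition_size=PARTITION_SIZE):
    """Hypothetical helper: prefix sum of per-sequence partition counts.

    indptr[i]..indptr[i + 1] is the range of global partition ids that
    belong to sequence i (assumed CSR-style convention).
    """
    # Ceiling division: a sequence of n tokens needs ceil(n / partition_size) partitions.
    counts = [(n + partition_size - 1) // partition_size for n in context_lens]
    return [0] + list(accumulate(counts))


indptr = build_partition_indptr([100, 256, 257, 1000])
print(indptr)
```

With this convention the last entry of the array is the total number of partitions, which is the quantity a persistent kernel would distribute across its workers.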

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 2, 2026 05:46
Contributor

Copilot AI left a comment

Pull request overview

Vendors aiter's get_pa_metadata_v1 / get_pa_metadata_info_v1 as a native FlyDSL kernel and re-ports the matching persistent FP8 decode kernel onto current main. The new metadata kernel removes a runtime dependency on an unpinned aiter HEAD that had been producing uninitialized output, and the decode kernel is rewritten around a 256-token partition layout with CU oversubscription. Parameterization on head_dim is plumbed through both PA decode kernels and the SW helpers, replacing the hard-coded HEAD_SIZE=128 / QUERY_GROUP_SIZE=16 constants.

Changes:

  • New kernels/pa_metadata.py implementing the PA worklist/reduce-map scheduler as a single-warp FlyDSL kernel, drop-in for aiter.ops.attention.get_pa_metadata_v1.
  • Rewrite of kernels/pa_decode_fp8.py: compile_pa_decode_ps renamed to compile_pa_decode_metadata, new partition-granularity worklist, oversubscribed persistent grid, small-block (16/64) path with optional metadata routing, _expand_pa_metadata_for_block_splits removed.
  • pa_decode_swa.py and the FP8 helpers are parameterized on head_dim / qkhe_loop / vhe_loop / q_lanes_per_head; the test harness is updated to the new get_recommended_splits and broader BLOCK_SIZE_OPTIONS.
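
To make the worklist and reduce-map idea concrete, here is a minimal host-side sketch. It is not the single-warp FlyDSL kernel from the PR: the function name `build_worklist`, the contiguous-chunk balancing policy, and the `oversubscription` factor of 2 are all assumptions chosen for illustration.

```python
def build_worklist(partition_indptr, num_cu, oversubscription=2):
    """Hypothetical scheduler: split partitions evenly over persistent workers.

    Returns (worklist, reduce_map) where worklist[w] is the list of
    (sequence, partition_in_sequence) items for worker w, and
    reduce_map[seq] lists the workers whose partial results must be
    combined to finish that sequence.
    """
    num_workers = num_cu * oversubscription  # assumed oversubscribed grid size
    # Flatten every sequence's partitions into one ordered list of work items.
    items = [
        (seq, part)
        for seq in range(len(partition_indptr) - 1)
        for part in range(partition_indptr[seq + 1] - partition_indptr[seq])
    ]
    # Give each worker a contiguous chunk; the first `extra` workers get one more item.
    base, extra = divmod(len(items), num_workers)
    worklist, start = [], 0
    for worker in range(num_workers):
        size = base + (1 if worker < extra else 0)
        worklist.append(items[start:start + size])
        start += size
    # Record which workers touched each sequence so a reduce pass can merge them.
    reduce_map = {}
    for worker, chunk in enumerate(worklist):
        for seq, _ in chunk:
            workers = reduce_map.setdefault(seq, [])
            if worker not in workers:
                workers.append(worker)
    return worklist, reduce_map


worklist, reduce_map = build_worklist([0, 1, 2, 4, 8], num_cu=2)
print(worklist)
print(reduce_map)
```

In this toy input the longest sequence is split across two workers, so its entry in the reduce map has two workers, while short sequences need no cross-worker reduction.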

Reviewed changes

Copilot reviewed 3 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| kernels/pa_metadata.py | New FlyDSL worklist scheduler replacing the aiter dependency. |
| kernels/pa_decode_fp8.py | Renamed/restructured PS decode kernel, new partition layout, small-block path, head_dim parameterization, switches metadata source to local module. |
| kernels/pa_decode_swa.py | Removes HEAD_SIZE/QUERY_GROUP_SIZE constants and threads head_dim / qkhe_loop / vhe_loop through helpers. |
| tests/kernels/test_pa.py | Updates harness to get_recommended_splits, drops obsolete block_size/head_size guards, expands sliding-window block_size matrix. |

Comment thread kernels/pa_decode_fp8.py
@flyc.kernel
def pa_decode_ps_kernel(
@flyc.kernel(known_block_size=(BLOCK_THREADS, 1, 1))
def pa_decode_metadata_kenrel(
Comment thread kernels/pa_decode_fp8.py
Comment on lines +1744 to +1747
# value_attrs=_mfma_agpr_value_attrs(),
).launch(grid=(num_sm, 1, 1), block=(BLOCK_THREADS, 1, 1), stream=stream)

launch_pa_decode_ps.compile_hints["llvm_options"] = PA_MFMA_AGPR_LLVM_OPTIONS
# launch_pa_decode_metadata.compile_hints["llvm_options"] = PA_MFMA_AGPR_LLVM_OPTIONS
Comment thread kernels/pa_metadata.py
Comment on lines +65 to +67
gpu = torch.cuda.current_device()
num_cu = torch.cuda.get_device_properties(gpu).multi_processor_count
cu_num = num_cu
fsx950223 and others added 4 commits June 2, 2026 07:38
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Fixes the Python style CI check (ruff isort) on PR #606.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…n_kv for small blocks

The grid small-block kernel compile_pa_decode_ps existed only to carry the
p_scale, v_scale_per_head, and k_scale_per_token quant variants. Remove all
three flags (and the grid kernel + its host wiring), and instead implement
per_token_kv=True for small block_size 16/64 in the metadata path so every
KV-quant mode routes through compile_pa_decode_metadata.

- Delete compile_pa_decode_ps and its dedicated helpers; small blocks now
  always use the load-balanced worklist (metadata) kernel.
- Strip p_scale / k_scale_per_token from _make_pa_phase_helpers
  (_k_scale_per_token_eff collapses to per_token_kv).
- compile_pa_decode_metadata: remove the small-block per_token_kv
  NotImplementedError; add _meta_stage_small_block_kv_scales (per-token K/V
  scale gather across the partition's gathered pages, matching the existing
  scale_lds_f32 layout) + _meta_load_small_block_scale_vecs. Small blocks
  stage scales fresh per iteration (no loop-carried prefetch).
- pa_decode_ps_launch: drop p_scale/p_scale_inv/v_scale_per_head/
  k_scale_per_token params; per_token_kv = key_scale.ndim > 1; remove the grid
  dispatch branch and the _PA_DECODE_PS_SMALL_BLOCK_VIA_METADATA flag.

Verified on MI308X (gfx942): tests/kernels/test_pa.py passes, including the
block_size=16 + per_token case that previously hit NotImplementedError.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review, keep the grid small-block kernel compile_pa_decode_ps and split
routing by block size: 16/64 -> compile_pa_decode_ps, >=256 (1024) ->
compile_pa_decode_metadata. The three removed scale flags
(p_scale/v_scale_per_head/k_scale_per_token) stay gone.

- Restore compile_pa_decode_ps, stripped of p_scale/v_scale_per_head and the
  old packed k_scale_per_token machinery. Implement per_token_kv=True in it:
  per-token K/V scales use the SAME layout as the metadata kernel
  ([num_blocks, num_kv_heads, block_size], indexed
  phys*stride_ks_block + kv_h*stride_ks_head + token_in_page), staged fresh
  per partition from the gathered phys-page LDS via _stage_small_block_kv_scales
  + _load_small_block_scale_vecs, feeding the shared per_token_kv path
  (_qk_and_intra_softmax preloaded_scales, _store_vmax_warp, cross-warp v_max).
- pa_decode_ps_launch: dispatch 16/64 -> compile_pa_decode_ps + sw_reduce;
  pass per_token_kv K/V scale strides (metadata layout); 1024 stays metadata.
- compile_pa_decode_metadata: revert the small-block per_token_kv support added
  in the previous commit (restore NotImplementedError); small blocks no longer
  route there.

Verified on MI308X (gfx942): tests/kernels/test_pa.py passes — block_size=16
+ per_token now runs through compile_pa_decode_ps; block_size=1024 through
compile_pa_decode_metadata.
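
The routing rule and the per-token scale indexing described in this commit can be sketched in plain Python. The helper names `select_decode_kernel` and `per_token_scale_offset` are hypothetical, the strides assume a contiguous [num_blocks, num_kv_heads, block_size] tensor, and raising on other block sizes is an assumption since the commit only names 16, 64, and >=256.

```python
def select_decode_kernel(block_size):
    """Hypothetical dispatcher mirroring the routing stated in the commit message."""
    if block_size in (16, 64):
        return "compile_pa_decode_ps"        # grid small-block kernel
    if block_size >= 256:
        return "compile_pa_decode_metadata"  # load-balanced worklist kernel
    # Assumption: other block sizes are treated as unsupported in this sketch.
    raise ValueError(f"unsupported block_size: {block_size}")


def per_token_scale_offset(phys, kv_h, token_in_page, num_kv_heads, block_size):
    """Flat offset into a contiguous [num_blocks, num_kv_heads, block_size] scale tensor."""
    stride_ks_head = block_size                  # one head's scales for a page
    stride_ks_block = num_kv_heads * block_size  # one physical page of scales
    return phys * stride_ks_block + kv_h * stride_ks_head + token_in_page


print(select_decode_kernel(16), select_decode_kernel(1024))
print(per_token_scale_offset(phys=2, kv_h=1, token_in_page=5,
                             num_kv_heads=8, block_size=16))
```

Because both kernels index scales the same way, a single scale tensor can be passed to either path without repacking.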

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>