Add warp-specialized bidirectional sendrecv with adaptive stride#2728
Open
ChenYuHo wants to merge 2 commits into
Open
Add warp-specialized bidirectional sendrecv with adaptive stride#2728ChenYuHo wants to merge 2 commits into
ChenYuHo wants to merge 2 commits into
Conversation
Contributor
|
@ChenYuHo has exported this pull request. If you are a Meta employee, you can view the originating Diff in D106331816. |
a7505ee to
6eeb0f4
Compare
ChenYuHo
pushed a commit
to ChenYuHo/torchcomms
that referenced
this pull request
May 29, 2026
…a-pytorch#2728) Summary: Adds a TLX warp-specialized (WS) bidirectional sendrecv kernel that runs sender and receiver as concurrent async tasks within the same CTA, matching NCCL's kernel structure. Combined with adaptive block stride tuning, this closes the bidirectional performance gap vs NCCL from 0.66x to 0.98x-1.83x across all message sizes at runtime-SM-matched comparison. Key changes: **Warp-specialized kernel** (`sendrecv_ws.py`): - Each CTA runs sender + receiver as concurrent TLX `async_task`s. Sender uses the canonical `tlx.async_task("default", ...)` form (the only label TLX accepts for the default task); receiver uses kwargs-only `tlx.async_task(num_warps=...)` (the explicit, non-default branch). - Parameterized via `NUM_SENDER_WARPS`/`NUM_RECEIVER_WARPS` constexpr. - Grid = `nccl_grid` (same SM count as NCCL, both directions per CTA). - TLX DSL provided by `//triton:triton`; the `tests:test_sendrecv` target's TLX dep is pulled in by a later commit so this kernel can JIT-compile during testing. **Adaptive stride** (`sendrecv_op.py`): - WS path dynamically increases `BLOCK_STRIDE_BYTES` up to `256KB * MAX_BLOCKS / num_blocks` to limit pipeline steps to <=2 per block (TLX `sync_threads()` has an iteration-count-dependent barrier bug; staying <=2 iters is the safe envelope). - Reduces `fence.acq_rel.sys` operations from ~64 to ~4 per block at 1GB. - Also benefits unidirectional: 0.92x to 1.29x NCCL at 1GB with `TRITON_NVL_BLOCK_STRIDE_BYTES=2097152`. - Fail-loud guard: launcher raises `RuntimeError` if the realized iter count > 2 for the given (numel, num_blocks, BLOCK_STRIDE_BYTES) combination, instead of silently hanging the kernel. **Host validation** (`sendrecv_op.py`): - WS path requires `sender_warps + receiver_warps` to be a multiple of 4 in `[4, 16]` (TLX warp-spec hard requirement; smaller / non-multiple-of-4 totals are rejected by the downstream MLIR `TritonTLXFixup` pass). **Benchmark fixes** (`benchmark_sendrecv.py`): - Fixed ws_only path to use `nccl_grid` (was `max(nccl_grid//2, 1)`). - Added `_WS_MAX_BYTES` env var for safe WS/stable fallback at large messages. - Added WS warp config env vars and display. **Test gating** (`tests/test_sendrecv.py`): - WS sweep cases (warp configs `(2,2)`, `(4,4)`, `(3,5)` x sizes `256KB`, `1MB`) plus the WS CUDA-graph case are gated behind the `TRITON_NVL_WS_TESTS_ENABLED` env var (default off). Reason: the upstream TLX `sync_threads()` codegen still hangs at runtime even with the API-level fixes above, so leaving the WS sweep enabled would hang the test binary. The cases stay inline so they re-activate automatically once a TLX-fixed build is in place. See the comment block near `_WS_TESTS_ENABLED` for context. Bidirectional results (cap=4, H100, SM-matched, vs NCCL): Warp-specialized (32B-64MB): ``` 32B-64KB: 1.49x-1.83x (protocol overhead wins) 1MB: 1.16x (was 0.89x, +30%) 8MB: 1.11x (was 0.76x, +46%) 64MB: 1.12x (was 0.69x, +62%, needs BLOCK_STRIDE=2MB) ``` Stable kernel with 2MB stride (256MB-1GB fallback): ``` 64MB: 0.98x (was 0.69x, +42%) 256MB: 0.98x (was 0.67x, +46%) 1GB: 0.98x (was 0.66x, +48%) ``` Both Triton and NCCL use copy-based staging (verified: `NCCL_GRAPH_REGISTER=0` produces identical NCCL throughput, confirming P2P Simple protocol uses `connEltsFifo` staging, not zero-copy DirectWrite). Known limitation: TLX `sync_threads()` (`bar.sync 0`) has an iteration-count-dependent barrier bug in multi-warp async_tasks — hangs after >2 while-loop iterations even when the host adaptive-stride math has reduced the iter count to within bounds. Adaptive stride is necessary but not sufficient to make WS fully reliable on the in-tree TLX. Larger messages (>~64MB at default num_blocks) fall back to the stable bidirectional kernel; the WS test sweep is gated off until upstream TLX is fixed. Differential Revision: D106331816
6eeb0f4 to
a58045f
Compare
added 2 commits
May 29, 2026 13:28
…peer (meta-pytorch#2727) Summary: Extends the Triton NVLink copy-based sendrecv primitive from 2-rank-only to N-rank groups, enabling Ring AllGather composition where `send_peer = (rank+1) % N` and `recv_peer = (rank-1) % N` differ. API changes (`sendrecv_op.py`): - `triton_nvl_sendrecv` now takes `send_peer` and optional `recv_peer` (defaults to `send_peer` for the v1 single-peer bidirectional case). - `triton_nvl_send` / `triton_nvl_recv` rename `peer_rank` to `send_peer` / `recv_peer` respectively. - Drops the `world_size == 2` assertion. Any `world_size >= 2` is supported. Staging + signal layout (N-rank generalization): - Staging buffer sized to `_per_peer_bytes() * world_size`. Each rank's symm-mem staging is partitioned by sender rank (matches `all_to_all_single.py`): rank R's staging at `[sender * per_peer_elems, (sender+1) * per_peer_elems)` holds slots for traffic from `sender -> R`. - Signal pad uses per-(peer, block) layout: TAIL at `[peer * MBPP + block_id]`, HEAD at `[HEAD_OFFSET + peer * MBPP + block_id]` where `HEAD_OFFSET = world_size * MAX_BLOCKS_PER_PEER`. Required signal-pad size is checked against `symm_mem.get_signal_pad_size()` at first allocation (only on cold-cache path; documented in docstring). - Renames `_MAX_BLOCKS_PER_DIR` -> `_MAX_BLOCKS_PER_PEER`. - Step-state tensors reshape from `(MBPP,)` to `(world_size, MBPP)` -- one persistent monotonic counter per (peer, block) pair. Cache key changes from `group` to `(group, device)`. Kernel changes (`sendrecv.py`): - Drops `local_rank` and `peer_rank` kernel arguments. The host pre-resolves all 6 staging/signal pointers (`send_staging_buf`, `recv_staging_buf`, `send_tail_sig`, `send_head_sig`, `recv_tail_sig`, `recv_head_sig`) plus pre-sliced `sender_step_ptr` / `recver_step_ptr` for the specific `(send_peer, recv_peer)` pair. Kernel indexes by `block_id` only. - All pointers are still passed as direct typed tensor arguments (slices of base tensors), preserving the vectorized-store codegen pattern. - Keeps `send_numel` and `recv_numel` specialized. This is performance-critical: de-specializing numel makes the tile-count and copy-loop bounds runtime values and regresses 1GB cap=4 from ~8.8ms to ~134ms. - Restored the intra-block `sync_threads()` invariant comments + signal-pointer-pre-resolution narrative in the module docstring (the 4x sender/receiver `sync_threads()` sites are critical correctness hooks; future maintainers must not reorder them). - Updated module docstring + ASCII diagram to reflect the asymmetric-peer (ring-style) flow. Review hardening (this revision): - Host signal-pad slices are now bounded to exactly MBPP int64 entries via a `_sig_row` helper instead of open-ended `[start:]`. The kernel only indexes `block_id < num_blocks <= MBPP`, so an out-of-range index in a future kernel change now fails hard instead of silently reading the neighbouring peer's row. - Dropped the now-unused `MAX_BLOCKS_PER_PEER` kernel constexpr (the host pre-resolves all per-(peer, block) offsets, so the kernel never needs it). Removed from both the kernel signature and the launch call. - Documented that `triton_nvl_sendrecv` bidirectional first use on a `world_size > 2` group is **not** guarded against sparse participation and will hang (not error) without a prior `triton_nvl_sendrecv_prepare`. The send-only / recv-only paths already raise; ring-style schedules where every rank enters the same step are safe. 2-rank runtime-SM-matched benchmark results remain in the same performance envelope as D105983710. The 2-rank path now routes through pre-sliced per-peer pointers, but retains the same direct tensor argument pattern, vectorized store codegen, and numel-specialized loop bounds. Unidirectional results on H100, runtime-SM-matched (cap=4): ``` msg_size | triton GB/s | nccl GB/s | triton/nccl 1MB | 44.2 | 41.3 | 1.07x 8MB | 59.2 | 66.2 | 0.89x 64MB | 120.3 | 132.2 | 0.91x 256MB | 121.5 | 132.7 | 0.92x 1GB | 121.9 | 132.8 | 0.92x ``` cap=16: ``` 1MB | 105.2 | 67.5 | 1.56x 1GB | 337.8 | 364.9 | 0.93x ``` A regression experiment confirmed the specialization requirement: with `do_not_specialize=["send_numel", "recv_numel"]`, cap=4 1GB dropped to 8.0 GB/s (134ms). Removing that de-specialization restores 121.9 GB/s (8.8ms), matching the previous commit. Differential Revision: D106171598
Summary: Adds a TLX warp-specialized (WS) bidirectional sendrecv kernel: each CTA runs sender and receiver as concurrent async_tasks, matching NCCL's per-channel send/recv structure. At runtime-SM-matched comparison (cap=4, H100) it is the top performer among the Triton variants at large messages: 1.10x-1.13x vs NCCL across 8MB-1GB, and 1.4x-1.7x at small sizes. Intra-task synchronization uses tl.debug_barrier(), which the TLX warp-spec lowering rewrites into a per-task, correctly-sized named barrier (default region -> barrier 0; each partition -> barrier 2+idx, sized to that partition's warp count). This is the crux: a raw inline-asm bar.sync 0 (the original approach) is invisible to that lowering, so it stays a full-CTA barrier across the two independent async_tasks -- which deadlocks at higher loop-iteration counts and corrupts data at small warp counts. Switching to a compiler-visible barrier fixes both; the kernel is now deterministically correct and hang-free at any iteration count (validated to 1GB / ~128 iters). Host-side validation: sender_warps and receiver_warps must each be a power of two and sum to a power of two in [2,16] (TLX warp-spec + Triton num_warps requirement, which forces sender_warps == receiver_warps). Adaptive block stride is a performance optimization only (it lowers the loop-iteration count, reducing fence.acq_rel.sys operations -- a few % at large messages); it is NOT required for correctness, so there is no fail-loud iteration guard. Differential Revision: D106331816
a58045f to
9b6c4e5
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
Adds a TLX warp-specialized (WS) bidirectional sendrecv kernel: each CTA runs sender and receiver as concurrent async_tasks, matching NCCL's per-channel send/recv structure. At runtime-SM-matched comparison (cap=4, H100) it is the top performer among the Triton variants at large messages: 1.10x-1.13x vs NCCL across 8MB-1GB, and 1.4x-1.7x at small sizes.
Intra-task synchronization uses tl.debug_barrier(), which the TLX warp-spec lowering rewrites into a per-task, correctly-sized named barrier (default region -> barrier 0; each partition -> barrier 2+idx, sized to that partition's warp count). This is the crux: a raw inline-asm bar.sync 0 (the original approach) is invisible to that lowering, so it stays a full-CTA barrier across the two independent async_tasks -- which deadlocks at higher loop-iteration counts and corrupts data at small warp counts. Switching to a compiler-visible barrier fixes both; the kernel is now deterministically correct and hang-free at any iteration count (validated to 1GB / ~128 iters).
Host-side validation: sender_warps and receiver_warps must each be a power of two and sum to a power of two in [2,16] (TLX warp-spec + Triton num_warps requirement, which forces sender_warps == receiver_warps).
Adaptive block stride is a performance optimization only (it lowers the loop-iteration count, reducing fence.acq_rel.sys operations -- a few % at large messages); it is NOT required for correctness, so there is no fail-loud iteration guard.
Differential Revision: D106331816