Skip to content

Add warp-specialized bidirectional sendrecv with adaptive stride#2728

Open
ChenYuHo wants to merge 2 commits into
meta-pytorch:mainfrom
ChenYuHo:export-D106331816
Open

Add warp-specialized bidirectional sendrecv with adaptive stride#2728
ChenYuHo wants to merge 2 commits into
meta-pytorch:mainfrom
ChenYuHo:export-D106331816

Conversation

@ChenYuHo
Copy link
Copy Markdown

@ChenYuHo ChenYuHo commented May 28, 2026

Summary:
Adds a TLX warp-specialized (WS) bidirectional sendrecv kernel: each CTA runs sender and receiver as concurrent async_tasks, matching NCCL's per-channel send/recv structure. At runtime-SM-matched comparison (cap=4, H100) it is the top performer among the Triton variants at large messages: 1.10x-1.13x vs NCCL across 8MB-1GB, and 1.4x-1.7x at small sizes.

Intra-task synchronization uses tl.debug_barrier(), which the TLX warp-spec lowering rewrites into a per-task, correctly-sized named barrier (default region -> barrier 0; each partition -> barrier 2+idx, sized to that partition's warp count). This is the crux: a raw inline-asm bar.sync 0 (the original approach) is invisible to that lowering, so it stays a full-CTA barrier across the two independent async_tasks -- which deadlocks at higher loop-iteration counts and corrupts data at small warp counts. Switching to a compiler-visible barrier fixes both; the kernel is now deterministically correct and hang-free at any iteration count (validated to 1GB / ~128 iters).

Host-side validation: sender_warps and receiver_warps must each be a power of two and sum to a power of two in [2,16] (TLX warp-spec + Triton num_warps requirement, which forces sender_warps == receiver_warps).

Adaptive block stride is a performance optimization only (it lowers the loop-iteration count, reducing fence.acq_rel.sys operations -- a few % at large messages); it is NOT required for correctness, so there is no fail-loud iteration guard.

Differential Revision: D106331816

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 28, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented May 28, 2026

@ChenYuHo has exported this pull request. If you are a Meta employee, you can view the originating Diff in D106331816.

@ChenYuHo ChenYuHo force-pushed the export-D106331816 branch from a7505ee to 6eeb0f4 Compare May 28, 2026 23:13
@meta-codesync meta-codesync Bot changed the title Add warp-specialized bidirectional sendrecv with adaptive stride Add warp-specialized bidirectional sendrecv with adaptive stride (#2728) May 28, 2026
ChenYuHo pushed a commit to ChenYuHo/torchcomms that referenced this pull request May 29, 2026
…a-pytorch#2728)

Summary:

Adds a TLX warp-specialized (WS) bidirectional sendrecv kernel that runs sender and receiver as concurrent async tasks within the same CTA, matching NCCL's kernel structure. Combined with adaptive block stride tuning, this closes the bidirectional performance gap vs NCCL from 0.66x to 0.98x-1.83x across all message sizes at runtime-SM-matched comparison.

Key changes:

**Warp-specialized kernel** (`sendrecv_ws.py`):
- Each CTA runs sender + receiver as concurrent TLX `async_task`s. Sender uses the canonical `tlx.async_task("default", ...)` form (the only label TLX accepts for the default task); receiver uses kwargs-only `tlx.async_task(num_warps=...)` (the explicit, non-default branch).
- Parameterized via `NUM_SENDER_WARPS`/`NUM_RECEIVER_WARPS` constexpr.
- Grid = `nccl_grid` (same SM count as NCCL, both directions per CTA).
- TLX DSL provided by `//triton:triton`; the `tests:test_sendrecv` target's TLX dep is pulled in by a later commit so this kernel can JIT-compile during testing.

**Adaptive stride** (`sendrecv_op.py`):
- WS path dynamically increases `BLOCK_STRIDE_BYTES` up to `256KB * MAX_BLOCKS / num_blocks` to limit pipeline steps to <=2 per block (TLX `sync_threads()` has an iteration-count-dependent barrier bug; staying <=2 iters is the safe envelope).
- Reduces `fence.acq_rel.sys` operations from ~64 to ~4 per block at 1GB.
- Also benefits unidirectional: 0.92x to 1.29x NCCL at 1GB with `TRITON_NVL_BLOCK_STRIDE_BYTES=2097152`.
- Fail-loud guard: launcher raises `RuntimeError` if the realized iter count > 2 for the given (numel, num_blocks, BLOCK_STRIDE_BYTES) combination, instead of silently hanging the kernel.

**Host validation** (`sendrecv_op.py`):
- WS path requires `sender_warps + receiver_warps` to be a multiple of 4 in `[4, 16]` (TLX warp-spec hard requirement; smaller / non-multiple-of-4 totals are rejected by the downstream MLIR `TritonTLXFixup` pass).

**Benchmark fixes** (`benchmark_sendrecv.py`):
- Fixed ws_only path to use `nccl_grid` (was `max(nccl_grid//2, 1)`).
- Added `_WS_MAX_BYTES` env var for safe WS/stable fallback at large messages.
- Added WS warp config env vars and display.

**Test gating** (`tests/test_sendrecv.py`):
- WS sweep cases (warp configs `(2,2)`, `(4,4)`, `(3,5)` x sizes `256KB`, `1MB`) plus the WS CUDA-graph case are gated behind the `TRITON_NVL_WS_TESTS_ENABLED` env var (default off). Reason: the upstream TLX `sync_threads()` codegen still hangs at runtime even with the API-level fixes above, so leaving the WS sweep enabled would hang the test binary. The cases stay inline so they re-activate automatically once a TLX-fixed build is in place. See the comment block near `_WS_TESTS_ENABLED` for context.

Bidirectional results (cap=4, H100, SM-matched, vs NCCL):

Warp-specialized (32B-64MB):
```
  32B-64KB: 1.49x-1.83x  (protocol overhead wins)
       1MB: 1.16x         (was 0.89x, +30%)
       8MB: 1.11x         (was 0.76x, +46%)
      64MB: 1.12x         (was 0.69x, +62%, needs BLOCK_STRIDE=2MB)
```

Stable kernel with 2MB stride (256MB-1GB fallback):
```
      64MB: 0.98x         (was 0.69x, +42%)
     256MB: 0.98x         (was 0.67x, +46%)
       1GB: 0.98x         (was 0.66x, +48%)
```

Both Triton and NCCL use copy-based staging (verified: `NCCL_GRAPH_REGISTER=0` produces identical NCCL throughput, confirming P2P Simple protocol uses `connEltsFifo` staging, not zero-copy DirectWrite).

Known limitation: TLX `sync_threads()` (`bar.sync 0`) has an iteration-count-dependent barrier bug in multi-warp async_tasks — hangs after >2 while-loop iterations even when the host adaptive-stride math has reduced the iter count to within bounds. Adaptive stride is necessary but not sufficient to make WS fully reliable on the in-tree TLX. Larger messages (>~64MB at default num_blocks) fall back to the stable bidirectional kernel; the WS test sweep is gated off until upstream TLX is fixed.

Differential Revision: D106331816
@ChenYuHo ChenYuHo force-pushed the export-D106331816 branch from 6eeb0f4 to a58045f Compare May 29, 2026 04:43
Elton Ho added 2 commits May 29, 2026 13:28
…peer (meta-pytorch#2727)

Summary:

Extends the Triton NVLink copy-based sendrecv primitive from 2-rank-only to N-rank groups, enabling Ring AllGather composition where `send_peer = (rank+1) % N` and `recv_peer = (rank-1) % N` differ.

API changes (`sendrecv_op.py`):
- `triton_nvl_sendrecv` now takes `send_peer` and optional `recv_peer` (defaults to `send_peer` for the v1 single-peer bidirectional case).
- `triton_nvl_send` / `triton_nvl_recv` rename `peer_rank` to `send_peer` / `recv_peer` respectively.
- Drops the `world_size == 2` assertion. Any `world_size >= 2` is supported.

Staging + signal layout (N-rank generalization):
- Staging buffer sized to `_per_peer_bytes() * world_size`. Each rank's symm-mem staging is partitioned by sender rank (matches `all_to_all_single.py`): rank R's staging at `[sender * per_peer_elems, (sender+1) * per_peer_elems)` holds slots for traffic from `sender -> R`.
- Signal pad uses per-(peer, block) layout: TAIL at `[peer * MBPP + block_id]`, HEAD at `[HEAD_OFFSET + peer * MBPP + block_id]` where `HEAD_OFFSET = world_size * MAX_BLOCKS_PER_PEER`. Required signal-pad size is checked against `symm_mem.get_signal_pad_size()` at first allocation (only on cold-cache path; documented in docstring).
- Renames `_MAX_BLOCKS_PER_DIR` -> `_MAX_BLOCKS_PER_PEER`.
- Step-state tensors reshape from `(MBPP,)` to `(world_size, MBPP)` -- one persistent monotonic counter per (peer, block) pair. Cache key changes from `group` to `(group, device)`.

Kernel changes (`sendrecv.py`):
- Drops `local_rank` and `peer_rank` kernel arguments. The host pre-resolves all 6 staging/signal pointers (`send_staging_buf`, `recv_staging_buf`, `send_tail_sig`, `send_head_sig`, `recv_tail_sig`, `recv_head_sig`) plus pre-sliced `sender_step_ptr` / `recver_step_ptr` for the specific `(send_peer, recv_peer)` pair. Kernel indexes by `block_id` only.
- All pointers are still passed as direct typed tensor arguments (slices of base tensors), preserving the vectorized-store codegen pattern.
- Keeps `send_numel` and `recv_numel` specialized. This is performance-critical: de-specializing numel makes the tile-count and copy-loop bounds runtime values and regresses 1GB cap=4 from ~8.8ms to ~134ms.
- Restored the intra-block `sync_threads()` invariant comments + signal-pointer-pre-resolution narrative in the module docstring (the 4x sender/receiver `sync_threads()` sites are critical correctness hooks; future maintainers must not reorder them).
- Updated module docstring + ASCII diagram to reflect the asymmetric-peer (ring-style) flow.

Review hardening (this revision):
- Host signal-pad slices are now bounded to exactly MBPP int64 entries via a `_sig_row` helper instead of open-ended `[start:]`. The kernel only indexes `block_id < num_blocks <= MBPP`, so an out-of-range index in a future kernel change now fails hard instead of silently reading the neighbouring peer's row.
- Dropped the now-unused `MAX_BLOCKS_PER_PEER` kernel constexpr (the host pre-resolves all per-(peer, block) offsets, so the kernel never needs it). Removed from both the kernel signature and the launch call.
- Documented that `triton_nvl_sendrecv` bidirectional first use on a `world_size > 2` group is **not** guarded against sparse participation and will hang (not error) without a prior `triton_nvl_sendrecv_prepare`. The send-only / recv-only paths already raise; ring-style schedules where every rank enters the same step are safe.

2-rank runtime-SM-matched benchmark results remain in the same performance envelope as D105983710. The 2-rank path now routes through pre-sliced per-peer pointers, but retains the same direct tensor argument pattern, vectorized store codegen, and numel-specialized loop bounds.

Unidirectional results on H100, runtime-SM-matched (cap=4):

```
  msg_size |  triton GB/s |  nccl GB/s | triton/nccl
       1MB |        44.2  |      41.3  |   1.07x
       8MB |        59.2  |      66.2  |   0.89x
      64MB |       120.3  |     132.2  |   0.91x
     256MB |       121.5  |     132.7  |   0.92x
       1GB |       121.9  |     132.8  |   0.92x
```

cap=16:
```
       1MB |       105.2  |      67.5  |   1.56x
       1GB |       337.8  |     364.9  |   0.93x
```

A regression experiment confirmed the specialization requirement: with `do_not_specialize=["send_numel", "recv_numel"]`, cap=4 1GB dropped to 8.0 GB/s (134ms). Removing that de-specialization restores 121.9 GB/s (8.8ms), matching the previous commit.

Differential Revision: D106171598
Summary:
Adds a TLX warp-specialized (WS) bidirectional sendrecv kernel: each CTA runs sender and receiver as concurrent async_tasks, matching NCCL's per-channel send/recv structure. At runtime-SM-matched comparison (cap=4, H100) it is the top performer among the Triton variants at large messages: 1.10x-1.13x vs NCCL across 8MB-1GB, and 1.4x-1.7x at small sizes.

Intra-task synchronization uses tl.debug_barrier(), which the TLX warp-spec lowering rewrites into a per-task, correctly-sized named barrier (default region -> barrier 0; each partition -> barrier 2+idx, sized to that partition's warp count). This is the crux: a raw inline-asm bar.sync 0 (the original approach) is invisible to that lowering, so it stays a full-CTA barrier across the two independent async_tasks -- which deadlocks at higher loop-iteration counts and corrupts data at small warp counts. Switching to a compiler-visible barrier fixes both; the kernel is now deterministically correct and hang-free at any iteration count (validated to 1GB / ~128 iters).

Host-side validation: sender_warps and receiver_warps must each be a power of two and sum to a power of two in [2,16] (TLX warp-spec + Triton num_warps requirement, which forces sender_warps == receiver_warps).

Adaptive block stride is a performance optimization only (it lowers the loop-iteration count, reducing fence.acq_rel.sys operations -- a few % at large messages); it is NOT required for correctness, so there is no fail-loud iteration guard.

Differential Revision: D106331816
@ChenYuHo ChenYuHo force-pushed the export-D106331816 branch from a58045f to 9b6c4e5 Compare May 29, 2026 20:28
@meta-codesync meta-codesync Bot changed the title Add warp-specialized bidirectional sendrecv with adaptive stride (#2728) Add warp-specialized bidirectional sendrecv with adaptive stride May 29, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot. fb-exported meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant