[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057
plugyawn wants to merge 4 commits
Greptile Summary
This PR fixes a launch-scaling bug in the THD fused RoPE path where the old kernel launched
Confidence Score: 4/5
Safe to merge for the common case; two previously-flagged open concerns in the CUDA kernel remain unaddressed and could surface under adversarial or mismatched inputs.
The new token-linear kernels correctly reproduce the original kernel's per-token math for all valid inputs: the binary search is sound, the CP-rank offset formula is identical to the existing path, and the parity test covers key edge cases. The two open issues from prior review rounds (redundant per-thread binary search, missing out-of-range guard) are the main reason not to score higher.
transformer_engine/common/fused_rope/fused_rope.cu: the two new token-linear kernels and the heuristic dispatcher are the only code paths that need a second look.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_rope_forward / fused_rope_backward\n(C++ binding layer)"] --> B["Read total_tokens = input.shape[0]\n(THD only)"]
B --> C["fused_rope_thd_use_token_linear\n(host heuristic)"]
C -->|"env=0 OR nseq<64 OR\nfreqs_len×nseq < 8×tokens"| E["Old kernel\ndim3 blocks(freqs_len, nseq)\nblockIdx.x = s_id, blockIdx.y = b_id\nmany dead blocks filtered at runtime"]
C -->|"env=1 OR\n(nseq≥64 AND freqs_len×nseq ≥ 8×tokens)"| D["Token-linear kernel\ndim3 blocks(total_tokens)\nblockIdx.x = t_id (linear token index)"]
D --> F["fused_rope_thd_find_seq_id\nbinary search on cu_seqlens\nto recover b_id from t_id"]
F --> G["Compute s_id = t_id - start\ncur_seqlens, begin_offset\nCP-rank freq offset"]
G --> H["fused_rope_block_forward/backward\n(same device helper as old path)"]
E --> H
Reviews (3): Last reviewed commit: "Merge branch 'main' into rope-thd-token-..."
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
Redundant binary search across all threads in the block
Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.
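The cost described in this comment can be modelled on the host. The following is a minimal C++ sketch, not the PR's code: `find_seq_id_counted` is a hypothetical stand-in for `fused_rope_thd_find_seq_id` (whose real signature and search details are not shown here), and the `reads` counter exists only for illustration. It counts how many `cu_seqlens` entries one search touches, then compares every thread repeating the search against a single search per block.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side stand-in for the device helper. Returns the largest
// b in [0, nseq) with cu_seqlens[b] / cp_size <= t_id, and adds the number of
// cu_seqlens entries it read to *reads.
int find_seq_id_counted(const std::vector<int>& cu_seqlens, int nseq, int t_id,
                        int cp_size, long* reads) {
  assert(nseq >= 1 && cp_size >= 1);
  int lo = 0, hi = nseq - 1;
  while (lo < hi) {
    int mid = lo + (hi - lo + 1) / 2;  // upper middle, so the range always shrinks
    ++*reads;                          // one cu_seqlens read per iteration
    if (cu_seqlens[mid] / cp_size <= t_id) {
      lo = mid;
    } else {
      hi = mid - 1;
    }
  }
  return lo;
}

// Reads issued by one block when every thread repeats the same search.
long reads_per_thread_search(const std::vector<int>& cu_seqlens, int nseq,
                             int t_id, int cp_size, int threads_per_block) {
  long reads = 0;
  for (int thread = 0; thread < threads_per_block; ++thread) {
    find_seq_id_counted(cu_seqlens, nseq, t_id, cp_size, &reads);
  }
  return reads;
}

// Reads issued when a single thread searches and shares the result.
long reads_single_search(const std::vector<int>& cu_seqlens, int nseq,
                         int t_id, int cp_size) {
  long reads = 0;
  find_seq_id_counted(cu_seqlens, nseq, t_id, cp_size, &reads);
  return reads;
}
```

For any input, the per-thread variant issues `threads_per_block` times as many reads as the single search. In a CUDA kernel, the single search would typically be done by thread 0, written to a `__shared__` variable, and made visible to the other threads with `__syncthreads()`.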
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size;
int s_id = t_id - start;
int cur_seqlens = end - start;
No guard for t_id exceeding valid cu_seqlens range
The old kernel explicitly filters dead blocks with if (t_id >= end) return; before any computation. The new kernel does not: it trusts that blockIdx.x < cu_seqlens[nseq]/cp_size because total_tokens is read from input.data.shape[0]. If a caller passes a tensor with shape[0] larger than cu_seqlens[-1]/cp_size, the binary search lands on b_id = nseq-1, computes s_id = t_id - start >= cur_seqlens, and fused_rope_block_forward indexes freqs at an out-of-range s_id_for_freqs. Adding if (t_id >= (int)(cu_seqlens[nseq] / cp_size)) return; after the binary search would restore the safety property the old kernel had.
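To make the failure mode concrete, here is a hedged host-side C++ model of the index math with the proposed guard in place. `locate_token` and `TokenCoord` are hypothetical names introduced for illustration; the search is an assumed equivalent of `fused_rope_thd_find_seq_id`, not a copy of it.

```cpp
#include <cassert>
#include <vector>

// Hypothetical result type: which sequence a packed token belongs to (b_id)
// and its position inside that sequence (s_id).
struct TokenCoord {
  bool valid;
  int b_id;
  int s_id;
};

// Hypothetical host-side model of the token-linear index math.
// cu_seqlens holds nseq + 1 cumulative lengths; cp_size is the
// context-parallel size that the kernel divides offsets by.
TokenCoord locate_token(const std::vector<int>& cu_seqlens, int cp_size, int t_id) {
  int nseq = static_cast<int>(cu_seqlens.size()) - 1;
  assert(nseq >= 1 && cp_size >= 1);

  // Largest b in [0, nseq) with cu_seqlens[b] / cp_size <= t_id.
  int lo = 0, hi = nseq - 1;
  while (lo < hi) {
    int mid = lo + (hi - lo + 1) / 2;
    if (cu_seqlens[mid] / cp_size <= t_id) {
      lo = mid;
    } else {
      hi = mid - 1;
    }
  }

  // Guard from the review: a token index past the last sequence does no work.
  // Without it, the search above lands on nseq - 1 and s_id runs past the
  // end of that sequence.
  if (t_id < 0 || t_id >= cu_seqlens[nseq] / cp_size) {
    return {false, -1, -1};
  }

  int start = cu_seqlens[lo] / cp_size;
  return {true, lo, t_id - start};
}
```

With `cu_seqlens = {0, 3, 7, 12}` and `cp_size = 1`, token 6 maps to sequence 1 at position 3, while token 12 is rejected instead of being treated as position 5 of a 5-token sequence.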
@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work
@sudhakarsingh27 Could you take a look?
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
331a3a0 to 6c46696
Thanks! Signed! fwiw I think the binary search overhead on normal cases can be reduced also, I'll probably add some improvements.
Description
Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.
Addresses #2866, which reports an interesting case where the THD fused RoPE launch scales with freqs_len × n_spans, which is pathological; it should scale with the total number of tokens. I reproduced the issue and found that it causes noticeable performance drops even on plausibly routine shapes, e.g. the [128/512] and [512/128] cases here.
The new kernel reuses the existing fused_rope_block_forward and fused_rope_block_backward device helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one block per packed token.
The problematic case is mostly pathological, however, so I've added a condition on the dispatch to avoid unnecessary binary search overhead, although that overhead appears to be small. The condition is: token-linear only when b >= 64 and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this is the usual shape of TE updates, so I could remove it!
Some more relevant tests:
Microbenchmark on H100 (bf16, h=32, d=d2=128, freqs_len=T_local=65536, single GPU):
Fixes: #2866.
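The dispatch condition described above can be sketched as a small host-side predicate. This is an illustration under assumptions, not the PR's code: the name `use_token_linear` and its signature are hypothetical (the flowchart names the real helper `fused_rope_thd_use_token_linear` but does not show its signature), and the environment value is passed in as a string so the function stays pure. A caller could pass `std::getenv("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR")`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical sketch of the host heuristic. env_override is the raw value of
// NVTE_FUSED_ROPE_THD_TOKEN_LINEAR, or nullptr when the variable is unset.
bool use_token_linear(std::int64_t freqs_len, std::int64_t nseq,
                      std::int64_t total_tokens, const char* env_override) {
  assert(freqs_len >= 0 && nseq >= 0 && total_tokens >= 0);

  // Explicit override: "0" forces the old kernel, "1" forces token-linear.
  if (env_override != nullptr && std::strcmp(env_override, "0") == 0) return false;
  if (env_override != nullptr && std::strcmp(env_override, "1") == 0) return true;

  // The old path launches a (freqs_len, nseq) grid; the token-linear path
  // launches total_tokens blocks. Switch only when there are at least 64
  // sequences and the old grid is at least 8x larger than the token count.
  const std::int64_t old_blocks = freqs_len * nseq;
  return nseq >= 64 && old_blocks >= 8 * total_tokens;
}
```

For example, with `freqs_len = 65536`, 64 sequences, and 65536 packed tokens, the old grid has 4,194,304 blocks against 65,536 tokens, so the predicate selects the token-linear kernel; with 32 sequences it keeps the old kernel regardless of the block ratio.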
Type of change
Changes
Please list the changes introduced in this PR:
NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
fused_rope_block_forward and fused_rope_block_backward device helpers.
Checklist: