
Fix int32 overflow in Triton _padded_copy pointer arithmetic #186

Open
hangg7 wants to merge 1 commit into databricks:main from hangg7:fix-int32-overflow-triton-kernels

Conversation


@hangg7 hangg7 commented Apr 2, 2026

Summary

Cast pointer offsets to tl.int64 before multiplying by NUM_COLUMNS in all 4 Triton kernels (_padded_copy, _padded_copy_wgrad, _binned_copy, _binned_copy_wgrad).

Problem

The Triton kernels compute pointer offsets as offset * NUM_COLUMNS using int32 arithmetic. In Triton, int32 * int32 stays int32 without promotion. When the product exceeds 2^31, the result wraps negative, creating a backward pointer that accesses memory before the tensor start — triggering CUDA error: an illegal memory access was encountered.
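The wraparound is easy to reproduce outside Triton. Here is a small sketch using NumPy int32 arithmetic as a stand-in for the kernel's offset math (the offset value is illustrative, not taken from the kernel):

```python
import numpy as np

NUM_COLUMNS = 4096
# A row offset a single rank can plausibly reach under routing imbalance.
offset = np.array([600_000], dtype=np.int32)

# int32 * int32 stays int32: 600_000 * 4096 = 2_457_600_000 > 2^31,
# so the product wraps to a negative element offset.
wrapped = offset * np.int32(NUM_COLUMNS)

# Widening the offset first keeps the product in int64.
safe = offset.astype(np.int64) * NUM_COLUMNS

print(int(wrapped[0]))  # -1837367296
print(int(safe[0]))     # 2457600000
```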

This is the same class of bug as triton-lang/triton#832.

When does it trigger?

With expert parallelism, the all-to-all dispatch can concentrate tokens on one rank due to routing imbalance. For hidden_size=4096, the overflow threshold is a row offset >= 524,288 (2^31 / 4096). At 20-30k tokens/GPU with EP, moderate routing imbalance (1.5-2x) is enough to cross it.
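The quoted threshold follows directly from the int32 limit: the element offset `offset * hidden_size` must stay below 2^31.

```python
hidden_size = 4096
int32_limit = 2**31

# Smallest row offset whose element offset no longer fits in int32.
threshold = int32_limit // hidden_size
print(threshold)  # 524288
```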

Symptoms

  • Stable at low token counts, crashes at 20k+ tokens/GPU
  • Probabilistic (depends on per-step routing distribution)
  • Manifests as NCCL "illegal memory access" errors

Fix

# Before (int32 overflow):
a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)

# After (int64 safe):
a += tl.multiple_of(offset.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)

Applied to all 4 kernels (8 lines total, casting offset and index_b). The extra .to(tl.int64) is one instruction per thread block: negligible performance impact.
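Note that widening only one operand is sufficient: once offset is int64, the multiplication is carried out in int64, following the usual integer-promotion rules. A NumPy sketch of that promotion behavior (illustrative, not the kernel itself):

```python
import numpy as np

offsets = np.array([100, 524_288, 600_000], dtype=np.int32)

# Only the left operand is widened; the multiply then happens in int64,
# so rows at or past the 524,288 threshold no longer wrap.
products = offsets.astype(np.int64) * 4096

print(products.dtype)     # int64
print(int(products[-1]))  # 2457600000
```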

Testing

  • 200 steps stable at 32k seq_len bs=20 on 8xH100 (64 experts, top_k=10, EP=4)
  • Previously crashed at step 52-58 without fix

