Skip to content

Triton RMSNorm Optimizations#593

Open
Micky774 wants to merge 11 commits into
devfrom
zain/rms-opt
Open

Triton RMSNorm Optimizations#593
Micky774 wants to merge 11 commits into
devfrom
zain/rms-opt

Conversation

@Micky774
Copy link
Copy Markdown
Contributor

@Micky774 Micky774 commented May 20, 2026

Description

Optimizes the Triton RMSNorm forward and backward kernels and adds an LDS-tiled FP8 transpose path. Measured 10%-50% improvements across a representative shape sweep for bf16 w/ no quantization or FP8 quant, and improvements of 3x-8x on FP8 Transpose outputs.

Benchmarks generated by this script.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Loop-invariant hoisting.
    • Fwd non-blocked path: gamma load + ZERO_CENTERED_GAMMA adjustment + 1/n_cols hoisted outside the persistent row loop.
    • Bwd non-blocked path: same gamma hoist; inv_n_cols hoisted.
    • Bwd both paths: per-row c_scalar = nf*nf*grad_sum*inv_n_cols computed once before the dx/dg loop; dx expression refactored to nf * (dz*g - c*x) (saves one multiply per element).
  • Autotune wiring for bwd kernels. _rmsnorm_bwd_triton and _rmsnorm_bwd_dg_reduce_triton now follow the impl + autotune-wrapper dispatch pattern already used by the fwd kernel. te_rmsnorm_bwd_triton takes an autotune: bool = True kwarg; when off it uses the previously-hardcoded num_warps=8 + fixed BLOCK_SIZE_M/N=128/64 reduce tile.
  • External LDS-tiled FP8 transpose kernel. New _fp8_transpose_2d_impl (+ autotune wrapper) replaces the in-kernel out_transpose_ptr + cols * stride + row_idx strided byte stores that were uncoalesced (one byte per thread to a different cache line). The new kernel does a coalesced (BLOCK_M, BLOCK_N) read, tl.trans() for LDS-staged transpose, then coalesced strided write.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Micky774 Micky774 added the ci-level 3 CI test level 3 label May 21, 2026
Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py
Comment thread transformer_engine/pytorch/triton_kernels/rmsnorm.py Outdated
Comment thread README.rst Outdated
Comment thread tests/pytorch/triton_kernels/test_norms.py Outdated
Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py Outdated
Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py Outdated
Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py Outdated
Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py Outdated
Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py
Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py Outdated
Comment thread transformer_engine/pytorch/triton_kernels/rmsnorm.py
Comment thread transformer_engine/pytorch/triton_kernels/rmsnorm.py Outdated
@Micky774 Micky774 requested a review from alextmagro May 29, 2026 17:13
Copy link
Copy Markdown
Contributor

@alextmagro alextmagro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Comment thread transformer_engine/pytorch/triton_kernels/rmsnorm.py
@Micky774 Micky774 requested a review from aris134 May 29, 2026 19:17
Comment thread transformer_engine/pytorch/triton_kernels/rmsnorm.py Outdated
Copy link
Copy Markdown
Contributor

@aris134 aris134 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Comment thread README.rst Outdated
* NVTE_USE_CAST_TRANSPOSE_TRITON=1 can be used to enable cast transpose (bgrad) triton kernels;
* NVTE_USE_LAYERNORM_TRITON=1 can be used to enable layernorm triton kernels.
* NVTE_USE_RMSNORM_TRITON=1 can be used to enable rmsnorm triton kernels.
* NVTE_RMS_EXTERNAL_TRANSPOSE=0 disables external transpose in RMSNorm Triton kernels and
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not used in code

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ipanfilo Could you check if the comments has been addressed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

@wenchenvincent
Copy link
Copy Markdown
Collaborator

@alextmagro @aris134 I saw you had approved the PR. For the inline comments, let's also resolve conversation if the comments has been addressed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 3 CI test level 3

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants