Skip to content
3 changes: 2 additions & 1 deletion tests/pytorch/triton_kernels/test_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,12 @@ def test_norm_triton(
zero_centered_gamma=zero_centered_gamma,

)
triton_bwd_outs = triton_bwd_func(*args["triton"])

if norm == "layer":
triton_bwd_outs = triton_bwd_func(*args["triton"])
dx_triton, dgamma_triton, dbeta_triton = triton_bwd_outs
elif norm == "rms":
triton_bwd_outs = triton_bwd_func(*args["triton"], autotune=autotune)
dx_triton, dgamma_triton = triton_bwd_outs
dbeta_triton = None

Expand Down
166 changes: 151 additions & 15 deletions transformer_engine/pytorch/triton_kernels/norms_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import triton
import triton.language as tl
import warnings
import transformer_engine_torch as tex

Expand All @@ -20,8 +21,78 @@
_rmsnorm_fwd_triton,
_rmsnorm_fwd_triton_impl,
_rmsnorm_bwd_triton,
_rmsnorm_bwd_triton_impl,
_rmsnorm_bwd_dg_reduce_triton,
_rmsnorm_bwd_dg_reduce_triton_impl,
)

# --------------------------------------------------------------------------- #
# External LDS-tiled byte transpose
#
# Produces the column-major fp8 transpose for the RMSNorm fwd path. The
# alternative -- having the main kernel emit `out_transpose_ptr + cols * stride
# + row_idx` strided stores -- is uncoalesced (1 byte/thread to a different
# cache line each) and bottlenecks every fp8_t shape, so RMSNorm always uses
# this kernel instead.
#
# This kernel reads a (BLOCK_M, BLOCK_N) tile coalesced from the row-major
# fp8 output, transposes it through LDS via `tl.trans`, and writes the
# (BLOCK_N, BLOCK_M) tile coalesced to the column-major transpose buffer.
#
# Operates on the raw uint8 storage so the fp8 dtype is irrelevant to
# correctness.
# --------------------------------------------------------------------------- #
@triton.jit
def _fp8_transpose_2d_impl(
src_ptr, # uint8 ptr, (n_rows, n_cols) row-major
dst_ptr, # uint8 ptr, (n_cols, n_rows) row-major
n_rows, n_cols,
src_stride, # element stride of src row dim (== n_cols when contig)
dst_stride, # element stride of dst row dim (== n_rows when contig)
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

# Coalesced read of (BLOCK_M, BLOCK_N) tile (innermost dim = cols).
src_offs = rm[:, None] * src_stride + cn[None, :]
src_mask = (rm[:, None] < n_rows) & (cn[None, :] < n_cols)
tile = tl.load(src_ptr + src_offs, mask=src_mask, other=0)

# LDS-staged transpose -> (BLOCK_N, BLOCK_M).
tile_t = tl.trans(tile)

# Coalesced write of (BLOCK_N, BLOCK_M) tile (innermost dim = rows).
dst_offs = cn[:, None] * dst_stride + rm[None, :]
dst_mask = (cn[:, None] < n_cols) & (rm[None, :] < n_rows)
tl.store(dst_ptr + dst_offs, tile_t, mask=dst_mask)


def _get_fp8_transpose_configs():
return [
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=8),
]


_fp8_transpose_2d_triton = triton.autotune(
configs=_get_fp8_transpose_configs(),
key=['n_rows', 'n_cols'],
use_cuda_graph=True,
)(_fp8_transpose_2d_impl)

_fp8_transpose_kernels = {
True: _fp8_transpose_2d_triton,
False: _fp8_transpose_2d_impl,
}
from .layernorm import (
_layernorm_fwd_triton,
_layernorm_fwd_triton_impl,
Expand All @@ -41,6 +112,16 @@
False: _layernorm_fwd_triton_impl,
}
}

_rmsnorm_bwd_kernels = {
True: _rmsnorm_bwd_triton,
False: _rmsnorm_bwd_triton_impl,
}

_rmsnorm_bwd_dg_reduce_kernels = {
True: _rmsnorm_bwd_dg_reduce_triton,
False: _rmsnorm_bwd_dg_reduce_triton_impl,
}
# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd
def te_rmsnorm_fwd_triton(
input: torch.Tensor,
Expand Down Expand Up @@ -152,6 +233,7 @@ def _te_norm_fwd_triton(
out_transpose_ptr = None
out_transpose_stride = None
FP8_MAX = None
use_external_transpose = False
if IS_FP8:
MAKE_TRANSPOSE = quantizer.columnwise_usage
amax = (
Expand All @@ -170,8 +252,13 @@ def _te_norm_fwd_triton(
dtype=out._data.dtype, device=device
)
out._transpose_invalid = False
out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype)
out_transpose_stride = out._transpose.stride(0)
if kernel == 'rms':
# RMSNorm always uses the external LDS-tiled transpose kernel.
use_external_transpose = True
else:
# LayerNorm emits the transpose via in-kernel strided stores.
out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype)
out_transpose_stride = out._transpose.stride(0)

grid_fwd = lambda meta: (NUM_PRGMS,)
kernel_func = _norm_kernels[kernel][autotune]
Expand All @@ -189,19 +276,20 @@ def _te_norm_fwd_triton(
q_amax_ptr=amax,
q_scale_ptr=q_scale,
scale_inv_ptr=scale_inv_ptr,
out_transpose_ptr=out_transpose_ptr,
out_transpose_stride=out_transpose_stride,
ZERO_CENTERED_GAMMA=zero_centered_gamma,
BLOCK_SIZE=BLOCK_SIZE,
IS_FP8=IS_FP8,
FP8_MAX=FP8_MAX,
MAKE_TRANSPOSE=MAKE_TRANSPOSE,
)
if kernel == 'layer':
kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC
kwargs["PERSISTENT"]=False # TODO: Improve persistent algo performance
kwargs["b_ptr"]=bias
kwargs["mean_ptr"]=mu
# LayerNorm emits the column-major fp8 copy via in-kernel strided stores.
kwargs["out_transpose_ptr"]=out_transpose_ptr
kwargs["out_transpose_stride"]=out_transpose_stride
kwargs["MAKE_TRANSPOSE"]=MAKE_TRANSPOSE
elif kernel == "rms":
kwargs["USE_BLOCKED"]=USE_BLOCKED
kwargs["NUM_PRGMS"]=NUM_PRGMS
Expand All @@ -216,6 +304,29 @@ def _te_norm_fwd_triton(

kernel_func[grid_fwd](**kwargs)

if use_external_transpose:
# out._data: (N rows, H cols) row-major uint8; out._transpose: (H, N).
transpose_kernel = _fp8_transpose_kernels[autotune]
if autotune:
grid_t = lambda meta: (
triton.cdiv(N, meta['BLOCK_M']),
triton.cdiv(H, meta['BLOCK_N']),
)
transpose_kernel[grid_t](
out._data, out._transpose,
N, H,
out._data.stride(0), out._transpose.stride(0),
)
else:
BLOCK_M, BLOCK_N = 64, 64
Comment thread
aris134 marked this conversation as resolved.
grid_t = (triton.cdiv(N, BLOCK_M), triton.cdiv(H, BLOCK_N))
transpose_kernel[grid_t](
out._data, out._transpose,
N, H,
out._data.stride(0), out._transpose.stride(0),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
)

# Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm.
if IS_FP8 and not APPLY_ATOMIC:
_layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)](
Expand All @@ -234,7 +345,7 @@ def _te_norm_fwd_triton(


# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd
def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, autotune: bool = True):
# may take non-contiguous inputs
dz_ = dz.contiguous()
x_ = x.contiguous()
Expand All @@ -248,25 +359,50 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
blk_size = block_size(x_)
USE_BLOCKED = use_blocked(x_)
NUM_PRGMS = num_programs(x_, sm_margin)
need_reduction = N > 1
dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin)
need_reduction = NUM_PRGMS > 1
dg_tmp_rows = M if USE_BLOCKED else NUM_PRGMS
dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None

input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0)
grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) * dz_.dtype.itemsize % 16 == 0)
dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) * dx.dtype.itemsize % 16 == 0)
dg_target = dg_tmp if need_reduction else dgamma
dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0)

grid_bwd = lambda meta: (NUM_PRGMS, )
_rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma,
x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size,
USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16,
dx_aligned_16, dg_aligned_16, num_warps=8)
bwd_kernel = _rmsnorm_bwd_kernels[autotune]
bwd_kwargs = dict(
Comment thread
alextmagro marked this conversation as resolved.
n_rows=M, n_cols=N,
ZERO_CENTERED_GAMMA=zero_centered_gamma,
BLOCK_SIZE=blk_size,
USE_BLOCKED=USE_BLOCKED, NUM_PRGMS=NUM_PRGMS,
INPUT_ALIGNED_16=input_aligned_16,
GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16,
DX_ALIGNED_16=dx_aligned_16,
DG_ALIGNED_16=dg_aligned_16,
)
if not autotune:
bwd_kwargs["num_warps"] = 8
bwd_kernel[grid_bwd](
dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma,
x_.stride(0), dz_.stride(0),
**bwd_kwargs,
)

if need_reduction:
grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
_rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
BLOCK_SIZE_M=128, BLOCK_SIZE_N=64)
reduce_kernel = _rmsnorm_bwd_dg_reduce_kernels[autotune]
if autotune:
grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
reduce_kernel[grid_reduce](
dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
)
else:
BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64
grid_reduce = (triton.cdiv(N, BLOCK_SIZE_N),)
reduce_kernel[grid_reduce](
dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
)

return dx, dgamma

Expand Down
Loading
Loading