diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index a4f11ba36..fdd6192af 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -298,11 +298,12 @@ def test_norm_triton( zero_centered_gamma=zero_centered_gamma, ) - triton_bwd_outs = triton_bwd_func(*args["triton"]) if norm == "layer": + triton_bwd_outs = triton_bwd_func(*args["triton"]) dx_triton, dgamma_triton, dbeta_triton = triton_bwd_outs elif norm == "rms": + triton_bwd_outs = triton_bwd_func(*args["triton"], autotune=autotune) dx_triton, dgamma_triton = triton_bwd_outs dbeta_triton = None diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index ed4002f2c..3c770ec8f 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -3,6 +3,7 @@ import torch import triton +import triton.language as tl import warnings import transformer_engine_torch as tex @@ -20,8 +21,78 @@ _rmsnorm_fwd_triton, _rmsnorm_fwd_triton_impl, _rmsnorm_bwd_triton, + _rmsnorm_bwd_triton_impl, _rmsnorm_bwd_dg_reduce_triton, + _rmsnorm_bwd_dg_reduce_triton_impl, ) + +# --------------------------------------------------------------------------- # +# External LDS-tiled byte transpose +# +# Produces the column-major fp8 transpose for the RMSNorm fwd path. The +# alternative -- having the main kernel emit `out_transpose_ptr + cols * stride +# + row_idx` strided stores -- is uncoalesced (1 byte/thread to a different +# cache line each) and bottlenecks every fp8_t shape, so RMSNorm always uses +# this kernel instead. +# +# This kernel reads a (BLOCK_M, BLOCK_N) tile coalesced from the row-major +# fp8 output, transposes it through LDS via `tl.trans`, and writes the +# (BLOCK_N, BLOCK_M) tile coalesced to the column-major transpose buffer. +# +# Operates on the raw uint8 storage so the fp8 dtype is irrelevant to +# correctness. +# --------------------------------------------------------------------------- # +@triton.jit +def _fp8_transpose_2d_impl( + src_ptr, # uint8 ptr, (n_rows, n_cols) row-major + dst_ptr, # uint8 ptr, (n_cols, n_rows) row-major + n_rows, n_cols, + src_stride, # element stride of src row dim (== n_cols when contig) + dst_stride, # element stride of dst row dim (== n_rows when contig) + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Coalesced read of (BLOCK_M, BLOCK_N) tile (innermost dim = cols). + src_offs = rm[:, None] * src_stride + cn[None, :] + src_mask = (rm[:, None] < n_rows) & (cn[None, :] < n_cols) + tile = tl.load(src_ptr + src_offs, mask=src_mask, other=0) + + # LDS-staged transpose -> (BLOCK_N, BLOCK_M). + tile_t = tl.trans(tile) + + # Coalesced write of (BLOCK_N, BLOCK_M) tile (innermost dim = rows). + dst_offs = cn[:, None] * dst_stride + rm[None, :] + dst_mask = (cn[:, None] < n_cols) & (rm[None, :] < n_rows) + tl.store(dst_ptr + dst_offs, tile_t, mask=dst_mask) + + +def _get_fp8_transpose_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_warps=8), + ] + + +_fp8_transpose_2d_triton = triton.autotune( + configs=_get_fp8_transpose_configs(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_fp8_transpose_2d_impl) + +_fp8_transpose_kernels = { + True: _fp8_transpose_2d_triton, + False: _fp8_transpose_2d_impl, +} from .layernorm import ( _layernorm_fwd_triton, _layernorm_fwd_triton_impl, @@ -41,6 +112,16 @@ False: _layernorm_fwd_triton_impl, } } + +_rmsnorm_bwd_kernels = { + True: _rmsnorm_bwd_triton, + False: _rmsnorm_bwd_triton_impl, +} + +_rmsnorm_bwd_dg_reduce_kernels = { + True: _rmsnorm_bwd_dg_reduce_triton, + False: _rmsnorm_bwd_dg_reduce_triton_impl, +} # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd def te_rmsnorm_fwd_triton( input: torch.Tensor, @@ -152,6 +233,7 @@ def _te_norm_fwd_triton( out_transpose_ptr = None out_transpose_stride = None FP8_MAX = None + use_external_transpose = False if IS_FP8: MAKE_TRANSPOSE = quantizer.columnwise_usage amax = ( @@ -170,8 +252,13 @@ def _te_norm_fwd_triton( dtype=out._data.dtype, device=device ) out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) - out_transpose_stride = out._transpose.stride(0) + if kernel == 'rms': + # RMSNorm always uses the external LDS-tiled transpose kernel. + use_external_transpose = True + else: + # LayerNorm emits the transpose via in-kernel strided stores. + out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) + out_transpose_stride = out._transpose.stride(0) grid_fwd = lambda meta: (NUM_PRGMS,) kernel_func = _norm_kernels[kernel][autotune] @@ -189,19 +276,20 @@ def _te_norm_fwd_triton( q_amax_ptr=amax, q_scale_ptr=q_scale, scale_inv_ptr=scale_inv_ptr, - out_transpose_ptr=out_transpose_ptr, - out_transpose_stride=out_transpose_stride, ZERO_CENTERED_GAMMA=zero_centered_gamma, BLOCK_SIZE=BLOCK_SIZE, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - MAKE_TRANSPOSE=MAKE_TRANSPOSE, ) if kernel == 'layer': kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC kwargs["PERSISTENT"]=False # TODO: Improve persistent algo performance kwargs["b_ptr"]=bias kwargs["mean_ptr"]=mu + # LayerNorm emits the column-major fp8 copy via in-kernel strided stores. + kwargs["out_transpose_ptr"]=out_transpose_ptr + kwargs["out_transpose_stride"]=out_transpose_stride + kwargs["MAKE_TRANSPOSE"]=MAKE_TRANSPOSE elif kernel == "rms": kwargs["USE_BLOCKED"]=USE_BLOCKED kwargs["NUM_PRGMS"]=NUM_PRGMS @@ -216,6 +304,29 @@ def _te_norm_fwd_triton( kernel_func[grid_fwd](**kwargs) + if use_external_transpose: + # out._data: (N rows, H cols) row-major uint8; out._transpose: (H, N). + transpose_kernel = _fp8_transpose_kernels[autotune] + if autotune: + grid_t = lambda meta: ( + triton.cdiv(N, meta['BLOCK_M']), + triton.cdiv(H, meta['BLOCK_N']), + ) + transpose_kernel[grid_t]( + out._data, out._transpose, + N, H, + out._data.stride(0), out._transpose.stride(0), + ) + else: + BLOCK_M, BLOCK_N = 64, 64 + grid_t = (triton.cdiv(N, BLOCK_M), triton.cdiv(H, BLOCK_N)) + transpose_kernel[grid_t]( + out._data, out._transpose, + N, H, + out._data.stride(0), out._transpose.stride(0), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. if IS_FP8 and not APPLY_ATOMIC: _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( @@ -234,7 +345,7 @@ def _te_norm_fwd_triton( # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): +def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, autotune: bool = True): # may take non-contiguous inputs dz_ = dz.contiguous() x_ = x.contiguous() @@ -248,8 +359,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): blk_size = block_size(x_) USE_BLOCKED = use_blocked(x_) NUM_PRGMS = num_programs(x_, sm_margin) - need_reduction = N > 1 - dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) + need_reduction = NUM_PRGMS > 1 + dg_tmp_rows = M if USE_BLOCKED else NUM_PRGMS dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0) @@ -257,16 +368,41 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) * dx.dtype.itemsize % 16 == 0) dg_target = dg_tmp if need_reduction else dgamma dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0) + grid_bwd = lambda meta: (NUM_PRGMS, ) - _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, - x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16, - dx_aligned_16, dg_aligned_16, num_warps=8) + bwd_kernel = _rmsnorm_bwd_kernels[autotune] + bwd_kwargs = dict( + n_rows=M, n_cols=N, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + BLOCK_SIZE=blk_size, + USE_BLOCKED=USE_BLOCKED, NUM_PRGMS=NUM_PRGMS, + INPUT_ALIGNED_16=input_aligned_16, + GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16, + DX_ALIGNED_16=dx_aligned_16, + DG_ALIGNED_16=dg_aligned_16, + ) + if not autotune: + bwd_kwargs["num_warps"] = 8 + bwd_kernel[grid_bwd]( + dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, + x_.stride(0), dz_.stride(0), + **bwd_kwargs, + ) if need_reduction: - grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] - _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], - BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) + reduce_kernel = _rmsnorm_bwd_dg_reduce_kernels[autotune] + if autotune: + grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + reduce_kernel[grid_reduce]( + dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + ) + else: + BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64 + grid_reduce = (triton.cdiv(N, BLOCK_SIZE_N),) + reduce_kernel[grid_reduce]( + dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, + ) return dx, dgamma diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 5ecb48eb7..fc82b7f9d 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -24,22 +24,16 @@ def _rmsnorm_fwd_triton_impl( q_amax_ptr, q_scale_ptr, scale_inv_ptr, - out_transpose_ptr, - out_transpose_stride, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - MAKE_TRANSPOSE: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, OUTPUT_ALIGNED_16: tl.constexpr, ): - # Enable the transpose cache only in FP8 mode. - tl.static_assert(not MAKE_TRANSPOSE or IS_FP8, "Transpose cache requires fp8 data type.") - row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) # as older version Triton doesn't support tl.assume and BUFF OPS, comment out for now @@ -108,9 +102,6 @@ def _rmsnorm_fwd_triton_impl( amax = tl.maximum(amax, amax_temp) rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) - if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx - tl.store(output_t_ptrs, rms_norm.to(output_type)) tl.store(output_ptrs, rms_norm.to(output_type)) # Handle remainder @@ -133,29 +124,28 @@ def _rmsnorm_fwd_triton_impl( amax = tl.maximum(amax, amax_temp) rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) - if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx - tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) else: mask = col_offsets < n_cols + # gamma is invariant across rows -- load + ZERO_CENTERED adjustment once per program. + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1 + inv_n_cols = 1.0 / n_cols for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets if INPUT_ALIGNED_16: input_ptrs = tl.multiple_of(input_ptrs, (16, )) row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) row_norm = row * row row_norm = tl.sum(row_norm, axis=-1) - norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon) + norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) # Store rsigma (norm_factor) rsigma_output_ptr = rsigma_ptr + row_idx tl.store(rsigma_output_ptr, norm_factor) - if (ZERO_CENTERED_GAMMA): - g += 1 rms_norm = row * norm_factor * g output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets @@ -166,9 +156,6 @@ def _rmsnorm_fwd_triton_impl( amax = tl.maximum(amax, amax_temp) rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) - if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + col_offsets * out_transpose_stride + row_idx - tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) if IS_FP8: tl.atomic_max(q_amax_ptr, amax, sem="relaxed") @@ -181,29 +168,27 @@ def _rmsnorm_fwd_triton_impl( _rmsnorm_fwd_triton = autotune_dec(_rmsnorm_fwd_triton_impl) @triton.jit -def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, +def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) + inv_n_cols = 1.0 / n_cols # tl.assume(input_row_stride >= 0) # tl.assume(output_row_stride >= 0) # tl.assume(row_start >= 0) if USE_BLOCKED: - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1): + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): row_input_ptr = input_ptr + row_idx * input_row_stride row_grad_output_ptr = grad_output_ptr + row_idx * output_row_stride row_dx_ptr = dx_ptr + row_idx * input_row_stride - row_dg_ptr = dg_ptr + row_idx * input_row_stride + row_dg_ptr = dg_ptr + row_idx * n_cols # Compute gradients sum of all colums for each row n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 - # older version of triton doesn't accept below init - # comment out for now to make it compatible with triton 3.1 - # grad_sum: tl.float32 = 0.0 grad_sum = 0.0 for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets @@ -238,6 +223,9 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d # Load r_sigma norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) + # Precomputed per-row invariant: c = nf*nf * grad_sum / n_cols + # used in calculating dx = nf * (dz*g - c*x) + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets @@ -256,8 +244,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g = tl.load(g_ptrs).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1. - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) dx_ptrs = row_dx_ptr + cols if DX_ALIGNED_16: @@ -268,7 +255,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d dg_ptrs = row_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - tl.store(dg_ptrs, dg.to(tl.float32)) + tl.store(dg_ptrs, dg) # Handle remainder cols = n_cols_blks * BLOCK_SIZE + col_offsets @@ -282,8 +269,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1. - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) dx_ptrs = row_dx_ptr + cols if DX_ALIGNED_16: @@ -294,12 +280,16 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d dg_ptrs = row_dg_ptr + cols if DG_ALIGNED_16: dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) - tl.store(dg_ptrs, dg.to(tl.float32), mask=mask) + tl.store(dg_ptrs, dg, mask=mask) else: mask = col_offsets < n_cols dg_col_redux = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1. + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets @@ -314,25 +304,31 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): - g += 1. norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) grad_sum = tl.sum(grad_output * x * g, axis=0) + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols - grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / - n_cols) + grad_input = norm_factor * (grad_output * g - c_scalar * x) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor dg_col_redux += dg.to(tl.float32) - tl.store(dg_ptr + tl.program_id(0) * input_row_stride + col_offsets, dg_col_redux, mask=mask) + tl.store(dg_ptr + row_start * n_cols + col_offsets, dg_col_redux, mask=mask) + + +# Autotune wrapper. Mirrors the fwd autotune layout so callers can toggle +# autotune via the same flag. +_rmsnorm_bwd_triton = triton.autotune( + configs=get_autotune_config(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_rmsnorm_bwd_triton_impl) @triton.jit -def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols, BLOCK_SIZE_M: tl.constexpr, +def _rmsnorm_bwd_dg_reduce_triton_impl(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # we want parallelism in N direction # if N is small, we will just use one CU, @@ -349,3 +345,23 @@ def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n sum_dg = tl.sum(acc, axis=0) tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols) + +def _get_dg_reduce_configs(): + # n_rows is NUM_PRGMS so the M dimension is small. + # The reduce kernel is <1% of bwd cost, so a tight 6-config sweep is plenty; + # bigger sweeps just pay first-call compile tax for marginal gain. + return [ + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64}, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128}, num_warps=8), + ] + + +_rmsnorm_bwd_dg_reduce_triton = triton.autotune( + configs=_get_dg_reduce_configs(), + key=['n_rows', 'n_cols'], + use_cuda_graph=True, +)(_rmsnorm_bwd_dg_reduce_triton_impl)