diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 835923828..460d6236f 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -1,4 +1,6 @@ /************************************************************************* +* This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -13,13 +15,42 @@ #include "../common.h" #include "../utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_device_utils.cuh" // for rocm_upper_bound(), NTVec +#endif namespace transformer_engine { namespace { +#ifdef __HIP_PLATFORM_AMD__ +// Non-temporal store helper: uses NT store for full aligned vectors, +// falls back to element-wise for partial/unaligned cases. +// Note: NT loads were also benchmarked but hurt performance. +template +__device__ __forceinline__ void nt_store_to_elts(const Vec& v, + Type* ptr, int count) { + constexpr size_t BYTES = nvec * sizeof(Type); + if (count == nvec && reinterpret_cast(ptr) % BYTES == 0) { + NTVec nt; +#pragma unroll + for (int i = 0; i < nvec; i++) nt.val[i] = v.data.elt[i]; + nt.nt_store(ptr); + } else { +#pragma unroll + for (int i = 0; i < nvec; i++) { + if (i < count) ptr[i] = v.data.elt[i]; + } + } +} +#endif + // Parameters to tune +#ifdef __HIP_PLATFORM_AMD__ +constexpr int n_warps_per_tile = 16; +#else constexpr int n_warps_per_tile = 4; +#endif constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; constexpr int desired_load_store_size = 8; constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB @@ -65,15 +96,22 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block +#ifdef __HIP_PLATFORM_AMD__ + const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); +#else int tensor_id = 0; while (args.block_range[tensor_id + 1] <= bid) { ++tensor_id; } +#endif const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int padded_num_rows = args.padded_num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; +#ifdef __HIP_PLATFORM_AMD__ + const bool inplace = (input == output); +#endif // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -83,6 +121,36 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; +#ifdef __HIP_PLATFORM_AMD__ + // Process subtiles with vectorized loads/stores +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; + const int col = tile_col + j1 * nvec; + const int remaining = row_length - col; + const int valid_cols = remaining > 0 ? min(remaining, nvec) : 0; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + if (row < num_rows) { + // Valid data row: skip copy when in-place + if (!inplace) { + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.load_from_elts(input, offset, valid_cols); + nt_store_to_elts(v, output + offset, valid_cols); + } + } else if (row < padded_num_rows) { + // Padding row: fill with zeros + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.clear(); + nt_store_to_elts(v, output + offset, valid_cols); + } + } + } +#else // !__HIP_PLATFORM_AMD__ // Load input and store to registers // Note: Each thread loads n_iterations subtiles, casts to output // type, and transposes in registers. @@ -125,6 +193,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP } } } +#endif // __HIP_PLATFORM_AMD__ } template @@ -150,14 +219,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; // Find tensor corresponding to block +#ifdef __HIP_PLATFORM_AMD__ + const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); +#else int tensor_id = 0; while (args.block_range[tensor_id + 1] <= bid) { ++tensor_id; } +#endif const Type* input = reinterpret_cast(args.input_list[tensor_id]); Type* output = reinterpret_cast(args.output_list[tensor_id]); const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; +#ifdef __HIP_PLATFORM_AMD__ + const bool inplace = (input == output); +#endif // Find position of tile within tensor const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; @@ -167,6 +243,27 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult const int tile_row = tile_id_m * tile_dim_m; const int tile_col = tile_id_n * tile_dim_n; +#ifdef __HIP_PLATFORM_AMD__ + // Process subtiles with vectorized loads/stores +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; + const int col = tile_col + j1 * nvec; + const int remaining = row_length - col; + const int valid_cols = remaining > 0 ? min(remaining, nvec) : 0; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + if (row < num_rows && !inplace) { + const size_t offset = static_cast(row) * row_length + col; + Vec v; + v.load_from_elts(input, offset, valid_cols); + nt_store_to_elts(v, output + offset, valid_cols); + } + } + } +#else // !__HIP_PLATFORM_AMD__ // Load input and store to registers // Note: Each thread loads n_iterations subtiles, casts to output // type, and transposes in registers. @@ -202,6 +299,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult } } } +#endif // __HIP_PLATFORM_AMD__ } } // namespace diff --git a/transformer_engine/common/util/rocm_device_utils.cuh b/transformer_engine/common/util/rocm_device_utils.cuh index 0d2b4c658..89c49b533 100644 --- a/transformer_engine/common/util/rocm_device_utils.cuh +++ b/transformer_engine/common/util/rocm_device_utils.cuh @@ -118,6 +118,23 @@ __device__ __forceinline__ void rocm_atomicMaxFloat(float *addr, float val) { atomicMax(reinterpret_cast(addr), __float_as_int(val)); } +// Binary search on a sorted array. +// Returns the largest index i in [0, n) such that arr[i] <= val. +// Precondition: arr is sorted in non-decreasing order and arr[0] <= val. +template +__device__ __forceinline__ int rocm_upper_bound(const T* arr, int n, T val) { + int lo = 0, hi = n - 1; + while (lo < hi) { + int mid = (lo + hi + 1) / 2; + if (arr[mid] <= val) { + lo = mid; + } else { + hi = mid - 1; + } + } + return lo; +} + template __device__ __forceinline__ float rocm_block_reduce_max(float val, int warp_id) { __shared__ float staging[WARPS];