Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions transformer_engine/common/util/padding.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
Expand All @@ -13,13 +15,42 @@

#include "../common.h"
#include "../utils.cuh"
#ifdef __HIP_PLATFORM_AMD__
#include "rocm_device_utils.cuh" // for rocm_upper_bound(), NTVec
#endif

namespace transformer_engine {

namespace {

#ifdef __HIP_PLATFORM_AMD__
// Stores the leading `count` elements of `v` at `ptr`.
//
// Fast path: when the whole vector is being written (count == nvec) and
// `ptr` is aligned to the vector size in bytes, the elements are staged in
// an NTVec and written out through its non-temporal store.
// Slow path: in every other case (partial vector or unaligned pointer) the
// first `count` elements are stored one at a time.
//
// Note: NT loads were also benchmarked but hurt performance.
template <uint32_t nvec, typename Type>
__device__ __forceinline__ void nt_store_to_elts(const Vec<Type, nvec>& v,
                                                 Type* ptr, int count) {
  constexpr size_t kVecBytes = sizeof(Type) * nvec;
  const bool whole_vector = (count == nvec);
  const bool vector_aligned = (reinterpret_cast<uint64_t>(ptr) % kVecBytes) == 0;

  if (!(whole_vector && vector_aligned)) {
    // Predicated element-wise stores; the trip count stays a compile-time
    // constant so the loop can still be fully unrolled.
#pragma unroll
    for (int e = 0; e < nvec; ++e) {
      if (e < count) {
        ptr[e] = v.data.elt[e];
      }
    }
    return;
  }

  // Full, aligned vector: copy into the NT container and issue one store.
  NTVec<Type, nvec> staged;
#pragma unroll
  for (int e = 0; e < nvec; ++e) {
    staged.val[e] = v.data.elt[e];
  }
  staged.nt_store(ptr);
}
#endif

// Parameters to tune
// Number of warps per thread block: 16 when building for AMD
// (__HIP_PLATFORM_AMD__), 4 otherwise. The kernels in this file compute
// n_iterations = THREADS_PER_WARP / n_warps_per_tile with integer division,
// so this value presumably needs to divide THREADS_PER_WARP evenly -- TODO confirm.
#ifdef __HIP_PLATFORM_AMD__
constexpr int n_warps_per_tile = 16;
#else
constexpr int n_warps_per_tile = 4;
#endif
// Threads per block; also the value passed to __launch_bounds__ on the
// padding/unpadding kernels.
constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
// Preferred width of one vectorized load/store (presumably in bytes;
// verify against the code that selects nvec).
constexpr int desired_load_store_size = 8;
// Cap on the number of tensors described by one kernel-argument struct
// (presumably one launch handles at most this many -- confirm at the call site).
constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB
Expand Down Expand Up @@ -65,15 +96,22 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

// Find tensor corresponding to block
#ifdef __HIP_PLATFORM_AMD__
const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid);
#else
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
#endif
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int padded_num_rows = args.padded_num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
#ifdef __HIP_PLATFORM_AMD__
const bool inplace = (input == output);
#endif

// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
Expand All @@ -83,6 +121,36 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;

#ifdef __HIP_PLATFORM_AMD__
// Process subtiles with vectorized loads/stores
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
const int col = tile_col + j1 * nvec;
const int remaining = row_length - col;
const int valid_cols = remaining > 0 ? min(remaining, nvec) : 0;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
if (row < num_rows) {
// Valid data row: skip copy when in-place
if (!inplace) {
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.load_from_elts(input, offset, valid_cols);
nt_store_to_elts(v, output + offset, valid_cols);
}
} else if (row < padded_num_rows) {
// Padding row: fill with zeros
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.clear();
nt_store_to_elts(v, output + offset, valid_cols);
}
}
}
#else // !__HIP_PLATFORM_AMD__
// Load input and store to registers
// Note: Each thread loads n_iterations subtiles, casts to output
// type, and transposes in registers.
Expand Down Expand Up @@ -125,6 +193,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
}
}
}
#endif // __HIP_PLATFORM_AMD__
}

template <int nvec, typename Type>
Expand All @@ -150,14 +219,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

// Find tensor corresponding to block
#ifdef __HIP_PLATFORM_AMD__
const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid);
#else
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
#endif
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
#ifdef __HIP_PLATFORM_AMD__
const bool inplace = (input == output);
#endif

// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
Expand All @@ -167,6 +243,27 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;

#ifdef __HIP_PLATFORM_AMD__
// Process subtiles with vectorized loads/stores
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
const int col = tile_col + j1 * nvec;
const int remaining = row_length - col;
const int valid_cols = remaining > 0 ? min(remaining, nvec) : 0;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
if (row < num_rows && !inplace) {
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.load_from_elts(input, offset, valid_cols);
nt_store_to_elts(v, output + offset, valid_cols);
}
}
}
#else // !__HIP_PLATFORM_AMD__
// Load input and store to registers
// Note: Each thread loads n_iterations subtiles, casts to output
// type, and transposes in registers.
Expand Down Expand Up @@ -202,6 +299,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
}
}
}
#endif // __HIP_PLATFORM_AMD__
}

} // namespace
Expand Down
17 changes: 17 additions & 0 deletions transformer_engine/common/util/rocm_device_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,23 @@ __device__ __forceinline__ void rocm_atomicMaxFloat(float *addr, float val) {
atomicMax(reinterpret_cast<int*>(addr), __float_as_int(val));
}

// Binary search on a sorted array.
//
// Returns the largest index i in [0, n) such that arr[i] <= val.
//
// Note: despite the name, this is not the std::upper_bound contract. It
// yields the index of the last element that is <= val, i.e. one less than
// the position std::upper_bound would report.
//
// Parameters:
//   arr - pointer to n elements sorted in non-decreasing order.
//   n   - number of elements to search.
//   val - value to locate.
//
// Precondition: arr is sorted in non-decreasing order and arr[0] <= val.
// If no element satisfies arr[i] <= val, 0 is returned. If n <= 1 the loop
// body never runs and 0 is returned without reading arr.
template <typename T>
__device__ __forceinline__ int rocm_upper_bound(const T* arr, int n, T val) {
  int lo = 0;
  int hi = n - 1;
  while (lo < hi) {
    // Round the midpoint up so that `lo = mid` always makes progress.
    // Computed relative to lo so the intermediate value cannot overflow
    // int, unlike (lo + hi + 1) / 2 when n is large.
    const int mid = lo + (hi - lo + 1) / 2;
    if (arr[mid] <= val) {
      lo = mid;
    } else {
      hi = mid - 1;
    }
  }
  return lo;
}

template <int WARPS>
__device__ __forceinline__ float rocm_block_reduce_max(float val, int warp_id) {
__shared__ float staging[WARPS];
Expand Down
Loading