Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ class EpilogueMoeFusedFinalize {
auto N = get<1>(problem_shape_mnkl);
auto L = get<3>(problem_shape_mnkl);

auto mma_tile_m = tile_size<0>(tiled_mma);
auto mma_tile_n = tile_size<1>(tiled_mma);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
constexpr auto mma_tile_m = decltype(tile_size<0>(tiled_mma)){};
constexpr auto mma_tile_n = decltype(tile_size<1>(tiled_mma)){};
constexpr auto epi_tile_m = size<0>(EpilogueTile{});
constexpr auto epi_tile_n = size<1>(EpilogueTile{});

CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
Expand Down Expand Up @@ -248,16 +248,17 @@ class EpilogueMoeFusedFinalize {
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)

// Make a tiled copy vectorized along major direction of D
constexpr int TiledMmaThreads = decltype(cute::size(tiled_mma))::value;
auto tiled_s2r = [&]() {
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>()) {
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
constexpr int NumThreadsMinor = TiledMmaThreads / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
Layout<Shape<_1, Int<AlignmentD>>>{});
} else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>()) {
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
constexpr int NumThreadsMinor = TiledMmaThreads / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
Layout<Shape<Int<AlignmentD>, _1>>{});
Expand All @@ -274,11 +275,11 @@ class EpilogueMoeFusedFinalize {
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)

// Allocate intermediate registers for a single subtile
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rBias = make_tensor<ElementBias>(shape(tSR_gBias(_, _, _, 0, 0))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rScale = make_tensor<ElementScale>(shape(tSR_gScale(_, _, _, 0, 0))); // ((S2R,S2R_V),S2R_M,S2R_N)

// Make an identity coordinate tensor for predicating our output MN tile
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,12 @@ struct MoeFCGemm {
run_kernel<arch::Sm80>(params, shared_storage);
}
#else
static_assert(
false, "Invalid architecture being compiled. Only Ampere+ supported in weight-only quantization kernels.");
// Pre-Ampere device compile pass: the MoeFCGemm body is unsupported on these archs,
// but NVCC must still emit *some* body for each requested target. Runtime dispatch
// in MoeGemmRunner::dispatchToArch() never invokes this kernel when sm_ < 80, so a
// device-side trap is safe and lets the same .cu compile cleanly under mixed arch
// lists (e.g. 52;61;75;86;89;90 in packaging pipelines).
CUTLASS_NOT_IMPLEMENTED();
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ ReturnType construct_if_true(Args&&... args)
{
if constexpr (FLAG)
{
return ReturnType{std::forward<Args>(args)...};
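// Use parenthesized direct-initialization instead of brace-init: list-initialization
// rejects narrowing conversions (MSVC C2397, e.g. size_t -> FastDivmod(int)), while
// direct-initialization permits them. For aggregate ReturnTypes this form relies on
// C++20 parenthesized aggregate initialization.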
return ReturnType(std::forward<Args>(args)...);
}
else
{
Expand Down
27 changes: 17 additions & 10 deletions onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1113,9 +1113,11 @@ void QMoE::PrePackSwizzleBlockScales(const Tensor& tensor, cudaStream_t stream,
p_src = temp_src_gpu.get();
}

// QMoEBlockScaleInterleaveKernel writes every byte of the output buffer
// (the (batch, row, col) -> offset map is a bijection over
// [0, batch_size) x [0, rows_padded) x [0, cols_padded), and padded
// source positions are written as 0), so no explicit memset is required.
packed_buf = IAllocator::MakeUniquePtr<void>(alloc, dst_bytes, true);
// Zero-fill for padding regions (kernel only writes within bounds)
CUDA_CALL_THROW(cudaMemsetAsync(packed_buf.get(), 0, dst_bytes, stream));

int multi_processor_count = 0;
int device_id = 0;
Expand Down Expand Up @@ -1250,16 +1252,23 @@ void QMoE::PrePackComputeBias(const Tensor& tensor, cudaStream_t stream, Allocat
return;
}

bool is_fp16 = is_fp16_;
bool is_bf16 = !is_fp16_;

ORT_ENFORCE(shape.NumDimensions() == 3, "Expected 3D zeros for block-wise 4-bit");
ORT_ENFORCE(shape[0] > 0 && shape[1] > 0 && shape[2] > 0,
"4-bit block-wise zeros must have positive dimensions, got ", shape.ToString());
// packed_k_blocks is doubled to k_blocks below; constrain it to half of INT_MAX to keep the
// doubled value (and the int dims passed into LaunchQMoEScaledZP4BitBatched) within int range.
constexpr int64_t kMaxPackedKBlocks = std::numeric_limits<int>::max() / 2;
ORT_ENFORCE(shape[0] <= std::numeric_limits<int>::max() &&
shape[1] <= std::numeric_limits<int>::max() &&
shape[2] <= kMaxPackedKBlocks,
"4-bit block-wise zeros dimensions exceed CUDA launch int range, got ", shape.ToString());
const int experts = static_cast<int>(shape[0]);
const int n = static_cast<int>(shape[1]);
const int packed_k_blocks = static_cast<int>(shape[2]);
const int k_blocks = packed_k_blocks * 2;
// QMoE only supports FP16/BF16 inputs (is_fp16_ is set in the ctor), both of which are 2 bytes.
size_t output_count = static_cast<size_t>(experts) * static_cast<size_t>(k_blocks) * static_cast<size_t>(n);
size_t bytes = output_count * (is_fp16 || is_bf16 ? 2 : 4);
size_t bytes = output_count * sizeof(uint16_t);
packed_bias = IAllocator::MakeUniquePtr<void>(alloc, bytes, true);

const void* p_src_zp = tensor.DataRaw();
Expand All @@ -1272,20 +1281,18 @@ void QMoE::PrePackComputeBias(const Tensor& tensor, cudaStream_t stream, Allocat

const uint8_t* zp_ptr = static_cast<const uint8_t*>(p_src_zp);
constexpr float kDefaultZeroPoint4Bit = 8.0f;
if (is_fp16) {
if (is_fp16_) {
LaunchQMoEScaledZP4BitBatched(
zp_ptr,
static_cast<const half*>(packed_scale.get()),
static_cast<half*>(packed_bias.get()),
experts, n, k_blocks, kDefaultZeroPoint4Bit, stream);
} else if (is_bf16) {
} else {
LaunchQMoEScaledZP4BitBatched(
zp_ptr,
static_cast<const __nv_bfloat16*>(packed_scale.get()),
static_cast<__nv_bfloat16*>(packed_bias.get()),
experts, n, k_blocks, kDefaultZeroPoint4Bit, stream);
} else {
ORT_THROW("Unsupported type for 4-bit block-wise ZP prepack. Expected FP16/BF16.");
}
}
CUDA_CALL_THROW(cudaStreamSynchronize(stream));
Expand Down
36 changes: 7 additions & 29 deletions onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,20 @@
// Licensed under the MIT License.

#include "contrib_ops/cuda/moe/qmoe_kernels.h"
#include "core/common/narrow.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/llm/moe_gemm/moe_kernels.h"
#include <cuda_bf16.h>
#include <cub/cub.cuh>
#include <algorithm>
#include <cfloat>
#include <limits>

namespace onnxruntime {
namespace contrib {
namespace cuda {

// Returns the number of blocks of `block_size` threads needed to cover
// `num_elements` items in a 1-D launch, i.e. ceil(num_elements / block_size).
//
// Preconditions (checked with ORT_ENFORCE):
//   - num_elements >= 0
//   - block_size > 0
//   - the resulting grid size fits in an int
int Compute1DGridSize(int num_elements, int block_size) {
  ORT_ENFORCE(num_elements >= 0, "CUDA launch element count must be non-negative, got ", num_elements);
  ORT_ENFORCE(block_size > 0, "CUDA launch block size must be positive, got ", block_size);
  // Widen to 64 bits before the round-up add: `num_elements + block_size - 1`
  // evaluated in int can overflow when num_elements is close to INT_MAX.
  const int64_t grid_size = (static_cast<int64_t>(num_elements) + block_size - 1) / block_size;
  ORT_ENFORCE(grid_size <= std::numeric_limits<int>::max(),
              "CUDA launch grid size exceeds int range: ", grid_size);
  return static_cast<int>(grid_size);
}

template <typename T>
Expand Down Expand Up @@ -698,11 +693,7 @@ void LaunchQMoEDequantizeFp4WeightsImpl(
cudaStream_t stream) {
int64_t total = static_cast<int64_t>(num_experts) * n * k;
constexpr int block = 256;
ORT_ENFORCE(total >= 0, "QMoEDequantizeFp4Weights: negative element count, got ", total);
int64_t grid_i64 = (total + block - 1) / block;
ORT_ENFORCE(grid_i64 <= std::numeric_limits<int>::max(),
"QMoEDequantizeFp4Weights: grid size exceeds int range: ", grid_i64);
int grid = static_cast<int>(grid_i64);
int grid = onnxruntime::narrow<int>((total + block - 1) / block);
QMoEDequantizeFp4WeightsKernel<<<grid, block, 0, stream>>>(
packed_weights, block_scales, global_scales, output, num_experts, n, k);
}
Expand Down Expand Up @@ -785,11 +776,7 @@ void LaunchQMoEDequantizeFp8WeightsImpl(
cudaStream_t stream) {
int64_t total = static_cast<int64_t>(num_experts) * n * k;
constexpr int block = 256;
ORT_ENFORCE(total >= 0, "QMoEDequantizeFp8Weights: negative element count, got ", total);
int64_t grid_i64 = (total + block - 1) / block;
ORT_ENFORCE(grid_i64 <= std::numeric_limits<int>::max(),
"QMoEDequantizeFp8Weights: grid size exceeds int range: ", grid_i64);
int grid = static_cast<int>(grid_i64);
int grid = onnxruntime::narrow<int>((total + block - 1) / block);
QMoEDequantizeFp8WeightsKernel<<<grid, block, 0, stream>>>(
weights, global_scales, output, num_experts, n, k);
}
Expand Down Expand Up @@ -862,16 +849,10 @@ void LaunchQMoERepackFP4ColToRow(
int64_t k,
int64_t n,
cudaStream_t stream) {
ORT_ENFORCE(experts > 0, "LaunchQMoERepackFP4ColToRow requires positive expert count, got ", experts);
ORT_ENFORCE(k > 0 && n > 0, "LaunchQMoERepackFP4ColToRow requires positive k and n, got k=", k, ", n=", n);
ORT_ENFORCE(k % 2 == 0 && n % 2 == 0,
"LaunchQMoERepackFP4ColToRow requires even k and n, got k=", k, ", n=", n);
const int64_t total = static_cast<int64_t>(experts) * n * (k / 2);
constexpr int kThreads = 256;
int64_t blocks = (total + kThreads - 1) / kThreads;
ORT_ENFORCE(blocks <= static_cast<int64_t>(std::numeric_limits<int>::max()),
"LaunchQMoERepackFP4ColToRow grid size exceeds int range");
QMoERepackFP4ColToRowKernel<<<static_cast<int>(blocks), kThreads, 0, stream>>>(
int blocks = onnxruntime::narrow<int>((total + kThreads - 1) / kThreads);
QMoERepackFP4ColToRowKernel<<<blocks, kThreads, 0, stream>>>(
input, output, experts, k, n);
}

Expand Down Expand Up @@ -901,10 +882,7 @@ __global__ void BatchedTransposeKernel(const T* __restrict__ input, T* __restric
void LaunchBatchedTranspose(cudaStream_t stream, const void* input, void* output, int batch, int rows, int cols, int element_size) {
int64_t total_elements = static_cast<int64_t>(batch) * rows * cols;
int threads = 256;
int64_t blocks_i64 = (total_elements + threads - 1) / threads;
ORT_ENFORCE(blocks_i64 <= std::numeric_limits<int>::max(),
"LaunchBatchedTranspose grid size exceeds int range: ", blocks_i64);
int blocks = static_cast<int>(blocks_i64);
int blocks = onnxruntime::narrow<int>((total_elements + threads - 1) / threads);

if (element_size == 1) {
BatchedTransposeKernel<uint8_t><<<blocks, threads, 0, stream>>>(static_cast<const uint8_t*>(input), static_cast<uint8_t*>(output), batch, rows, cols);
Expand Down
Loading