Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,

// Currently the accumulate path is only supported on fp16
if (accumulate && is_8bit_float) {
NVTE_WARN("ck_tile_grouped_gemm: accumulate is currently unsupported on fp8");
return false;
}

Expand Down Expand Up @@ -94,8 +95,11 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
}
}

const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype();
const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype();
const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data;
const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data;

const auto a_dtype = A0_data.dtype;
const auto b_dtype = B0_data.dtype;

Tensor* D0_te = convertNVTETensorCheck(D[0]);
const auto d_dtype = D0_te->dtype();
Expand Down Expand Up @@ -156,6 +160,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
B_use,
D,
static_cast<int>(n),
static_cast<int>(kA),
group_num,
transA_use,
transB_use,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ struct GroupedGemmRunContext {
const NVTETensor* B = nullptr;
NVTETensor* D = nullptr;
int64_t N = 0;
int64_t K = 0;

int group_num = 0;
bool transA = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class GroupedGemmRunner : public RunnerInterface {
}
};

#define MAKE_RUNNER(TileCfg_) \
#define MAKE_FP16_RUNNER(TileCfg_) \
TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \
using Runner = GroupedGemmRunner<AType, \
BType, \
Expand Down Expand Up @@ -231,11 +231,11 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
MAKE_FP16_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);
MAKE_FP16_RUNNER(TileCfg_256x128x64);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
MAKE_FP16_RUNNER(TileCfg_256x128x64_padding);
}
});
});
Expand All @@ -249,7 +249,7 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
return runner->run(s, ctx);
}

#undef MAKE_RUNNER
#undef MAKE_FP16_RUNNER

} // namespace grouped_gemm
} // namespace transformer_engine
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ enum class GPUArch {
UNKNOWN
};

struct TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
struct TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;

static constexpr ck_tile::index_t M_Warp = 2;
Expand All @@ -45,13 +45,41 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t TilePartitionerM01 = 8;
};

// 128x128 (M x N) tile configuration, expressed as a delta on the 256x256x128
// configuration: M_Tile and N_Tile are redeclared here and hide the base
// values; every other member is inherited unchanged from
// TileCfg_256x256x128_16x16x128_2x2x1.
//
// In the FP8 dispatch this is the GFX950 choice when every group's N is a
// multiple of 128 (but not all are multiples of 256) and every group's K is a
// multiple of 128.
struct TileCfg_128x128x128_16x16x128_2x2x1
: TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
};

// Variant of the 256x256x128 configuration that declares kPadK = true; all
// other members come from the base.
// Presumably kPadK turns on K-dimension padding in the CK kernel - confirm
// against the base config and the runner that reads it.
//
// NOTE(review): the GFX950 FP8 dispatch in this file does not appear to select
// this type; when N is 256-aligned but K is not 128-aligned it picks
// TileCfg_128x128x128_16x16x128_2x2x1_kpad instead. Confirm whether that is
// intentional or whether this config was meant to be used there.
struct TileCfg_256x256x128_16x16x128_2x2x1_kpad
: TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr bool kPadK = true;
};

// Variant of the 128x128x128 configuration that declares kPadK = true; all
// other members come from the base.
// Presumably kPadK turns on K-dimension padding in the CK kernel - confirm
// against the base config and the runner that reads it.
//
// The GFX950 FP8 dispatch selects this type when every group's N is a multiple
// of 128 (or of 256) but at least one group's K is not a multiple of 128.
struct TileCfg_128x128x128_16x16x128_2x2x1_kpad
: TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr bool kPadK = true;
};

// Variant of the 128x128x128 configuration that declares kPadN = true; all
// other members come from the base.
// Presumably kPadN turns on N-dimension padding in the CK kernel - confirm
// against the base config and the runner that reads it.
//
// The GFX950 FP8 dispatch selects this type when at least one group's N is not
// a multiple of 128 while every group's K is a multiple of 128.
struct TileCfg_128x128x128_16x16x128_2x2x1_npad
: TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr bool kPadN = true;
};

// Variant of the 128x128x128 configuration that declares both kPadN = true and
// kPadK = true; all other members come from the base.
// Presumably these turn on N- and K-dimension padding in the CK kernel -
// confirm against the base config and the runner that reads them.
//
// The GFX950 FP8 dispatch uses this type as the fallback when neither N nor K
// is 128-aligned across all groups.
struct TileCfg_128x128x128_16x16x128_2x2x1_nkpad
: TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
};

// gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile
// configuration due to an unsupported warp GEMM dispatcher configuration.
// See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants.
//
// To preserve the existing type name in shared template code, this struct
// inherits from the gfx950-safe 128x128x128 16x16x128 configuration in the
// gfx950 device compilation path, effectively reusing those parameters without
// redefining them.
//
// In all other compilation paths, the struct overrides the relevant fields to
// provide the intended 32x32x16 configuration.
Expand Down Expand Up @@ -261,7 +289,9 @@ class QuantGroupedGemmRunner : public RunnerInterface {
if (descs.empty()) {
return false;
}
return launch_grouped_gemm_kernel<Kernel>(descs, ctx, stream_cfg);

const bool launched = launch_grouped_gemm_kernel<Kernel>(descs, ctx, stream_cfg);
return launched;
}
};

Expand Down Expand Up @@ -290,6 +320,78 @@ struct FP8TileCfg<GPUArch::GFX950> {
using type = TileCfg_128x128x128_16x16x128_2x2x1;
};

// Alignment summary for the N and K extents of all groups in an FP8 grouped
// GEMM. Every flag starts out true and is meant to be cleared as soon as one
// group violates the corresponding alignment.
struct FP8GroupedShapeAlignment {
  // Every group's N is a multiple of 256.
  bool all_n_256_aligned{true};
  // Every group's N is a multiple of 128.
  bool all_n_128_aligned{true};
  // Every group's K is a multiple of 128.
  bool all_k_128_aligned{true};
};

// Scans the groups of `ctx` and reports whether N is 256-/128-aligned and K is
// 128-aligned across the groups inspected.
//
// For each group the 2D extents of A and B are read either from their
// columnwise storage (when ctx.use_{a,b}_columnwise_data is set) or from the
// flattened tensor. K is taken from A and N from B, with the axis chosen by
// ctx.transA / ctx.transB. NVTE_ERROR is raised if the extents of a tensor
// cannot be obtained.
//
// The scan stops early once all three flags have been cleared, so groups after
// that point are neither inspected nor validated. With ctx.group_num <= 0 the
// result keeps its defaults (all flags true).
static FP8GroupedShapeAlignment get_fp8_grouped_shape_alignment(
    const GroupedGemmRunContext& ctx) {
  FP8GroupedShapeAlignment result;

  for (int group = 0; group < ctx.group_num; ++group) {
    const transformer_engine::Tensor* const a_tensor =
        transformer_engine::convertNVTETensorCheck(ctx.A[group]);
    const transformer_engine::Tensor* const b_tensor =
        transformer_engine::convertNVTETensorCheck(ctx.B[group]);

    int64_t a_dim0 = 0;
    int64_t a_dim1 = 0;
    int64_t b_dim0 = 0;
    int64_t b_dim1 = 0;

    // Extents of A for this group.
    if (!ctx.use_a_columnwise_data) {
      if (!get_flat_2d_dims(*a_tensor, a_dim0, a_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A in group ", group);
      }
    } else {
      if (!get_columnwise_storage_2d_dims(a_tensor->columnwise_data, a_dim0, a_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for A in group ", group);
      }
    }

    // Extents of B for this group.
    if (!ctx.use_b_columnwise_data) {
      if (!get_flat_2d_dims(*b_tensor, b_dim0, b_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B in group ", group);
      }
    } else {
      if (!get_columnwise_storage_2d_dims(b_tensor->columnwise_data, b_dim0, b_dim1)) {
        NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B in group ", group);
      }
    }

    const int64_t k_extent = ctx.transA ? a_dim0 : a_dim1;
    const int64_t n_extent = ctx.transB ? b_dim0 : b_dim1;

    // A flag survives only while every group seen so far satisfies it.
    result.all_n_256_aligned = result.all_n_256_aligned && (n_extent % 256 == 0);
    result.all_n_128_aligned = result.all_n_128_aligned && (n_extent % 128 == 0);
    result.all_k_128_aligned = result.all_k_128_aligned && (k_extent % 128 == 0);

    // Nothing left to disprove: stop scanning.
    const bool any_flag_still_set = result.all_n_256_aligned ||
                                    result.all_n_128_aligned ||
                                    result.all_k_128_aligned;
    if (!any_flag_still_set) {
      break;
    }
  }

  return result;
}

// Instantiates a QuantGroupedGemmRunner for the given tile configuration and
// stores it in `runner`.
//
// Requirements at the expansion site:
//   - AType, BType, CType, ALayout, BLayout and CTypeLayout name types in scope;
//   - `runner` is assignable from std::unique_ptr<Runner>.
// The expansion declares a local alias named `Runner`, so each use must sit in
// its own scope (e.g. its own if/else branch) to avoid redeclaring the alias
// with a different type. The memory operation is fixed to `set`.
// The macro does not end with a semicolon; the caller supplies it.
#define MAKE_FP8_RUNNER(TileCfg_) \
using Runner = QuantGroupedGemmRunner<AType, \
BType, \
CType, \
ALayout, \
BLayout, \
CTypeLayout, \
TileCfg_, \
ck_tile::memory_operation_enum::set>; \
runner = std::make_unique<Runner>()

template <GPUArch Arch>
static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
DType b_dtype,
Expand All @@ -299,33 +401,55 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
std::unique_ptr<RunnerInterface> runner = nullptr;

using CTypeLayout = RowMajor;
using TileCfg = typename FP8TileCfg<Arch>::type;

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, {
using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, {
using BLayout = std::conditional_t<kTransB, ColMajor, RowMajor>;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, {
using AType = typename TETypeToCKType<a_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, {
using BType = typename TETypeToCKType<b_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;
using Runner = QuantGroupedGemmRunner<AType,
BType,
CType,
ALayout,
BLayout,
CTypeLayout,
TileCfg,
ck_tile::memory_operation_enum::set>;
runner = std::make_unique<Runner>();
});
});

// FP8 grouped GEMM is only compiled for CK's preferred NT presentation:
// transA=false, transB=true
// which maps to:
// ALayout=RowMajor, BLayout=ColMajor.
//
// The caller is responsible for rewriting other FP8 layouts into this form
// using columnwise_data when needed. Reject anything that did not normalize
// successfully so we do not instantiate unreachable/unsupported layout variants.
if (ctx.transA || !ctx.transB) {
return false;
}

using ALayout = RowMajor;
using BLayout = ColMajor;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, {
using AType = typename TETypeToCKType<a_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, {
using BType = typename TETypeToCKType<b_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if constexpr (Arch == GPUArch::GFX950) {
const auto alignment = get_fp8_grouped_shape_alignment(ctx);

if (alignment.all_n_256_aligned) {
if (alignment.all_k_128_aligned) {
MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1);
} else {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_kpad);
}
} else if (alignment.all_n_128_aligned) {
if (alignment.all_k_128_aligned) {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1);
} else {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_kpad);
}
} else if (alignment.all_k_128_aligned) {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_npad);
} else {
MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_nkpad);
}
} else {
using TileCfg = typename FP8TileCfg<Arch>::type;
MAKE_FP8_RUNNER(TileCfg);
}
});
});
});
Expand All @@ -334,9 +458,12 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
return false;
}

return runner->run(s, ctx);
const bool ok = runner->run(s, ctx);
return ok;
}

#undef MAKE_FP8_RUNNER

bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype,
DType b_dtype,
DType d_dtype,
Expand Down
15 changes: 13 additions & 2 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1160,9 +1160,20 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
#ifdef __HIP_PLATFORM_AMD__
auto A_dt = inputA->data.dtype;
auto B_dt = inputB->data.dtype;
// Chooses the dtype used to decide whether a tensor counts as FP8.
// The row-wise (`data`) dtype is used unless it is not FP8 while the tensor
// carries column-wise data whose dtype is FP8; only then is the column-wise
// dtype reported instead.
auto effective_dtype = [](const transformer_engine::Tensor* tensor) {
  const auto rowwise_dtype = tensor->data.dtype;
  const bool only_columnwise_is_fp8 =
      !is_fp8_dtype(rowwise_dtype) &&
      tensor->has_columnwise_data() &&
      is_fp8_dtype(tensor->columnwise_data.dtype);
  return only_columnwise_is_fp8 ? tensor->columnwise_data.dtype : rowwise_dtype;
};

auto A_dt = effective_dtype(inputA);
auto B_dt = effective_dtype(inputB);
auto D_dt = OutputD->data.dtype;

return (
(is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt))
) ||
Expand Down
Loading