Skip to content

Commit f57a5c6

Browse files
committed
feat(hygon-gemm): add Hygon backend support for Gemm
- add a Hygon `Gemm` backend on top of the shared CUDA BLAS path - use DTK-friendly compute and algo settings for fp32/fp16 gemm - fall back to `cublasGemmEx` for single-batch Hygon gemm to avoid DTK crashes - release Hygon cublas handles after each call and re-enable the `gemm` example - verified with `pip install -e .[dev]`, `pytest tests/test_gemm.py -k cuda`, and `pytest tests/test_gemm.py`
1 parent 5e5098b commit f57a5c6

3 files changed

Lines changed: 157 additions & 4 deletions

File tree

examples/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@ file(GLOB_RECURSE EXAMPLE_SOURCES CONFIGURE_DEPENDS "*.cc")
22

33
# Iterate through each file and create an executable.
44
foreach(source_file ${EXAMPLE_SOURCES})
5-
if(WITH_HYGON AND source_file MATCHES "/gemm\\.cc$")
6-
continue()
7-
endif()
8-
95
get_filename_component(example_name ${source_file} NAME_WE)
106

117
add_executable(${example_name} ${source_file})

examples/gemm/gemm.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#if WITH_ILUVATAR
1212
#include "iluvatar/gemm/cublas.h"
1313
#endif
14+
#if WITH_HYGON
15+
#include "hygon/gemm/cublas.h"
16+
#endif
1417
#if WITH_METAX
1518
#include "metax/gemm/mcblas.h"
1619
#endif

src/hygon/gemm/cublas.h

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
#ifndef INFINI_OPS_HYGON_GEMM_CUBLAS_H_
2+
#define INFINI_OPS_HYGON_GEMM_CUBLAS_H_
3+
4+
#include <utility>
5+
6+
// clang-format off
7+
#include "cublas_v2.h"
8+
// clang-format on
9+
10+
#include "cuda/gemm/blas.h"
11+
12+
namespace infini::ops {
13+
14+
namespace gemm {
15+
16+
struct HygonBackend {
17+
using blasHandle_t = cublasHandle_t;
18+
19+
using stream_t = cudaStream_t;
20+
21+
static constexpr auto BLAS_OP_N = CUBLAS_OP_N;
22+
23+
static constexpr auto BLAS_OP_T = CUBLAS_OP_T;
24+
25+
static constexpr auto R_16F = CUDA_R_16F;
26+
27+
static constexpr auto R_16BF = CUDA_R_16BF;
28+
29+
static constexpr auto R_32F = CUDA_R_32F;
30+
31+
static constexpr auto BLAS_COMPUTE_32F = CUBLAS_COMPUTE_32F;
32+
33+
// DTK exposes the TF32 enum for compatibility, but BW/GFX9-class Hygon
34+
// devices do not provide a working TF32 GEMM fast path.
35+
static constexpr auto BLAS_COMPUTE_32F_FAST_TF32 = CUBLAS_COMPUTE_32F;
36+
37+
static constexpr auto BLAS_GEMM_DEFAULT = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
38+
39+
static constexpr auto blasCreate = cublasCreate;
40+
41+
static constexpr auto blasSetStream = cublasSetStream;
42+
43+
static constexpr auto blasDestroy = cublasDestroy;
44+
45+
static constexpr auto blasGemmEx = [](auto&&... args) {
46+
return cublasGemmEx(std::forward<decltype(args)>(args)...);
47+
};
48+
49+
static constexpr auto blasGemmStridedBatchedEx = [](auto&&... args) {
50+
return cublasGemmStridedBatchedEx(std::forward<decltype(args)>(args)...);
51+
};
52+
53+
static auto GetDataType(DataType dtype) {
54+
if (dtype == DataType::kFloat16) return R_16F;
55+
if (dtype == DataType::kBFloat16) return R_16BF;
56+
return R_32F;
57+
}
58+
59+
static auto GetComputeType(DataType dtype) {
60+
if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16)
61+
return BLAS_COMPUTE_32F;
62+
return BLAS_COMPUTE_32F_FAST_TF32;
63+
}
64+
};
65+
66+
} // namespace gemm
67+
68+
template <>
69+
class Operator<Gemm, Device::Type::kHygon> : public Blas<gemm::HygonBackend> {
70+
public:
71+
using Blas<gemm::HygonBackend>::Blas;
72+
73+
void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
74+
std::optional<float> beta, std::optional<int> trans_a,
75+
std::optional<int> trans_b, Tensor c) const override {
76+
const bool a_is_col_major = a.stride(-1) == 1;
77+
const bool b_is_col_major = b.stride(-1) == 1;
78+
const bool swap_a_and_b = c.stride(-1) == 1;
79+
80+
auto get_op_a = [&](int trans_a_value, int trans_b_value) {
81+
if (swap_a_and_b) {
82+
return (b_is_col_major == trans_b_value) ? gemm::HygonBackend::BLAS_OP_T
83+
: gemm::HygonBackend::BLAS_OP_N;
84+
}
85+
return (a_is_col_major != trans_a_value) ? gemm::HygonBackend::BLAS_OP_T
86+
: gemm::HygonBackend::BLAS_OP_N;
87+
};
88+
89+
auto get_op_b = [&](int trans_a_value, int trans_b_value) {
90+
if (swap_a_and_b) {
91+
return (a_is_col_major == trans_a_value) ? gemm::HygonBackend::BLAS_OP_T
92+
: gemm::HygonBackend::BLAS_OP_N;
93+
}
94+
return (b_is_col_major != trans_b_value) ? gemm::HygonBackend::BLAS_OP_T
95+
: gemm::HygonBackend::BLAS_OP_N;
96+
};
97+
98+
gemm::HygonBackend::blasHandle_t handle{};
99+
gemm::HygonBackend::blasCreate(&handle);
100+
gemm::HygonBackend::blasSetStream(
101+
handle, static_cast<gemm::HygonBackend::stream_t>(this->stream_));
102+
103+
const auto& alpha_value{alpha.value_or(this->alpha_)};
104+
const auto& beta_value{beta.value_or(this->beta_)};
105+
106+
const auto& trans_a_value{trans_a.value_or(this->trans_a_)};
107+
const auto& trans_b_value{trans_b.value_or(this->trans_b_)};
108+
auto op_a{get_op_a(trans_a_value, trans_b_value)};
109+
auto op_b{get_op_b(trans_a_value, trans_b_value)};
110+
const void* alpha_ptr{this->GetAlphaPtr(alpha_value, c.dtype())};
111+
const void* beta_ptr{this->GetBetaPtr(beta_value, c.dtype())};
112+
113+
if (this->batch_count_ == 1) {
114+
gemm::HygonBackend::blasGemmEx(
115+
handle, op_a, op_b, swap_a_and_b ? this->n_ : this->m_,
116+
swap_a_and_b ? this->m_ : this->n_, this->k_, alpha_ptr,
117+
swap_a_and_b ? b.data() : a.data(),
118+
gemm::HygonBackend::GetDataType(swap_a_and_b ? b.dtype()
119+
: a.dtype()),
120+
swap_a_and_b ? this->ldb_ : this->lda_,
121+
swap_a_and_b ? a.data() : b.data(),
122+
gemm::HygonBackend::GetDataType(swap_a_and_b ? a.dtype()
123+
: b.dtype()),
124+
swap_a_and_b ? this->lda_ : this->ldb_, beta_ptr, c.data(),
125+
gemm::HygonBackend::GetDataType(c.dtype()), this->ldc_,
126+
gemm::HygonBackend::GetComputeType(c.dtype()),
127+
gemm::HygonBackend::BLAS_GEMM_DEFAULT);
128+
} else {
129+
gemm::HygonBackend::blasGemmStridedBatchedEx(
130+
handle, op_a, op_b, swap_a_and_b ? this->n_ : this->m_,
131+
swap_a_and_b ? this->m_ : this->n_, this->k_, alpha_ptr,
132+
swap_a_and_b ? b.data() : a.data(),
133+
gemm::HygonBackend::GetDataType(swap_a_and_b ? b.dtype()
134+
: a.dtype()),
135+
swap_a_and_b ? this->ldb_ : this->lda_,
136+
swap_a_and_b ? this->batch_stride_b_ : this->batch_stride_a_,
137+
swap_a_and_b ? a.data() : b.data(),
138+
gemm::HygonBackend::GetDataType(swap_a_and_b ? a.dtype()
139+
: b.dtype()),
140+
swap_a_and_b ? this->lda_ : this->ldb_,
141+
swap_a_and_b ? this->batch_stride_a_ : this->batch_stride_b_,
142+
beta_ptr, c.data(), gemm::HygonBackend::GetDataType(c.dtype()),
143+
this->ldc_, this->batch_stride_c_, this->batch_count_,
144+
gemm::HygonBackend::GetComputeType(c.dtype()),
145+
gemm::HygonBackend::BLAS_GEMM_DEFAULT);
146+
}
147+
148+
gemm::HygonBackend::blasDestroy(handle);
149+
}
150+
};
151+
152+
} // namespace infini::ops
153+
154+
#endif

0 commit comments

Comments
 (0)