Skip to content

Commit 43809e7

Browse files
committed
feat(ops): implement CausalSoftmax operator with Hygon backend.
1 parent f57a5c6 commit 43809e7

3 files changed

Lines changed: 89 additions & 2 deletions

File tree

src/cuda/causal_softmax/kernel.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33

4+
#include <algorithm>
45
#include <cassert>
56
#include <cstdint>
67

@@ -31,10 +32,11 @@ class CudaCausalSoftmax : public CausalSoftmax {
3132

3233
assert(out.dtype() == input.dtype());
3334

34-
int block_size = Backend::GetOptimalBlockSize();
35+
constexpr int kMaxBlockSize = BackendMaxBlockSize<Backend>::value;
36+
int block_size = std::min(Backend::GetOptimalBlockSize(), kMaxBlockSize);
3537

3638
DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
37-
AllCudaBlockSizes>(
39+
SupportedCudaBlockSizesType<BackendMaxBlockSize<Backend>::value>>(
3840
// TODO: Output dtype should use the one passed in during construction.
3941
{static_cast<int64_t>(out.dtype()), block_size},
4042
[&](auto list_tag) {

src/cuda/kernel_commons.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,54 @@
11
#ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_
22
#define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_
33

4+
#include <type_traits>
5+
46
#include "caster.h"
57

68
namespace infini::ops {
79

810
// Full set of block sizes offered for kernel dispatch. It is also the list
// selected by SupportedCudaBlockSizes<2048>.
using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>;
911

12+
// Compile-time trait yielding the largest block size a backend may use.
// Primary template: used when `Backend` has no `max_block_size` member, in
// which case the value is 2048 (the largest entry of AllCudaBlockSizes).
template <typename Backend, typename = void>
13+
struct BackendMaxBlockSize : std::integral_constant<int, 2048> {};
14+
15+
// Partial specialisation chosen through std::void_t detection whenever
// `Backend::max_block_size` is a well-formed expression; it forwards that
// member's value. The member must therefore be usable as a constant expression.
template <typename Backend>
16+
struct BackendMaxBlockSize<Backend,
17+
std::void_t<decltype(Backend::max_block_size)>>
18+
: std::integral_constant<int, Backend::max_block_size> {};
19+
20+
// Maps a maximum block size to the list of block sizes that do not exceed it.
// The primary template is only declared, never defined: the caps specialised
// below (2048, 1024, 512, 256, 128) are the only valid arguments, and any other
// value names an incomplete type and fails to compile when `::type` is used.
template <int max_block_size>
21+
struct SupportedCudaBlockSizes;
22+
23+
// Cap 2048: every size in AllCudaBlockSizes is allowed.
template <>
24+
struct SupportedCudaBlockSizes<2048> {
25+
using type = AllCudaBlockSizes;
26+
};
27+
28+
// Cap 1024.
template <>
29+
struct SupportedCudaBlockSizes<1024> {
30+
using type = List<128, 256, 512, 1024>;
31+
};
32+
33+
// Cap 512.
template <>
34+
struct SupportedCudaBlockSizes<512> {
35+
using type = List<128, 256, 512>;
36+
};
37+
38+
// Cap 256.
template <>
39+
struct SupportedCudaBlockSizes<256> {
40+
using type = List<128, 256>;
41+
};
42+
43+
// Cap 128: only the smallest size remains.
template <>
44+
struct SupportedCudaBlockSizes<128> {
45+
using type = List<128>;
46+
};
47+
48+
// Convenience alias for `SupportedCudaBlockSizes<max_block_size>::type`, so
// callers do not have to spell out `typename ...::type` themselves.
template <int max_block_size>
49+
using SupportedCudaBlockSizesType =
50+
typename SupportedCudaBlockSizes<max_block_size>::type;
51+
1052
__forceinline__ __device__ __host__ size_t
1153
IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape,
1254
const ptrdiff_t* strides) {

src/hygon/causal_softmax/kernel.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#ifndef INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
2+
#define INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
3+
4+
#include <utility>
5+
6+
// clang-format off
7+
#include <cuda_runtime.h>
8+
// clang-format on
9+
10+
// clang-format off
11+
#include "hygon/device_.h"
12+
// clang-format on
13+
14+
#include "cuda/causal_softmax/kernel.h"
15+
16+
namespace infini::ops {
17+
18+
namespace causal_softmax {
19+
20+
// Backend policy type supplied as the template argument of CudaCausalSoftmax
// for Hygon devices.
struct HygonBackend {
21+
// Stream handle type; aliases the CUDA runtime's cudaStream_t.
using stream_t = cudaStream_t;
22+
23+
// Device tag identifying this backend.
static constexpr Device::Type kDeviceType = Device::Type::kHygon;
24+
25+
// Upper bound on the block size for this backend. The name must stay
// `max_block_size`: BackendMaxBlockSize detects the member by exactly this
// spelling and falls back to 2048 if it is absent.
// NOTE(review): the rationale for 256 is not stated in this file; confirm
// against the Hygon hardware/toolchain limits.
static constexpr int max_block_size = 256;
26+
27+
// Returns the result of ComputeOptimalBlockSize applied to
// QueryMaxThreadsPerBlock(). Both helpers are defined outside this file's
// visible portion (presumably in "hygon/device_.h"; verify).
static int GetOptimalBlockSize() {
28+
return ComputeOptimalBlockSize(QueryMaxThreadsPerBlock());
29+
}
30+
};
31+
32+
} // namespace causal_softmax
33+
34+
// Operator specialisation for CausalSoftmax on Device::Type::kHygon. It adds
// no members of its own and derives from the shared CudaCausalSoftmax
// implementation instantiated with causal_softmax::HygonBackend.
template <>
35+
class Operator<CausalSoftmax, Device::Type::kHygon>
36+
: public CudaCausalSoftmax<causal_softmax::HygonBackend> {
37+
public:
38+
// Inherit every constructor of the base class unchanged.
using CudaCausalSoftmax<causal_softmax::HygonBackend>::CudaCausalSoftmax;
39+
};
40+
41+
} // namespace infini::ops
42+
43+
#endif

0 commit comments

Comments
 (0)