Skip to content

Commit 3959d01

Browse files
committed
feat(ops): implement CausalSoftmax operator with Moore backend
- add the Moore operator specialization - limit block size in the shared CUDA-style kernel for Moore
1 parent f17e37c commit 3959d01

2 files changed

Lines changed: 85 additions & 2 deletions

File tree

src/cuda/causal_softmax/kernel.h

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33

4+
#include <algorithm>
45
#include <cstdint>
6+
#include <type_traits>
57

68
#include "base/causal_softmax.h"
79
#include "cuda/causal_softmax/kernel.cuh"
@@ -11,6 +13,49 @@
1113

1214
namespace infini::ops {
1315

16+
namespace causal_softmax::detail {
17+
18+
template <typename Backend, typename = void>
19+
struct MaxBlockSize : std::integral_constant<int, 2048> {};
20+
21+
template <typename Backend>
22+
struct MaxBlockSize<Backend, std::void_t<decltype(Backend::max_block_size)>>
23+
: std::integral_constant<int, Backend::max_block_size> {};
24+
25+
template <int max_block_size>
26+
struct SupportedBlockSizes;
27+
28+
template <>
29+
struct SupportedBlockSizes<2048> {
30+
using type = AllCudaBlockSizes;
31+
};
32+
33+
template <>
34+
struct SupportedBlockSizes<1024> {
35+
using type = List<128, 256, 512, 1024>;
36+
};
37+
38+
template <>
39+
struct SupportedBlockSizes<512> {
40+
using type = List<128, 256, 512>;
41+
};
42+
43+
template <>
44+
struct SupportedBlockSizes<256> {
45+
using type = List<128, 256>;
46+
};
47+
48+
template <>
49+
struct SupportedBlockSizes<128> {
50+
using type = List<128>;
51+
};
52+
53+
template <typename Backend>
54+
using SupportedBlockSizesFor =
55+
typename SupportedBlockSizes<MaxBlockSize<Backend>::value>::type;
56+
57+
} // namespace causal_softmax::detail
58+
1459
template <typename Backend>
1560
class CudaCausalSoftmax : public CausalSoftmax {
1661
public:
@@ -32,10 +77,13 @@ class CudaCausalSoftmax : public CausalSoftmax {
3277
std::abort();
3378
}
3479

35-
int block_size = GetOptimalBlockSize();
80+
constexpr int kMaxBlockSize =
81+
causal_softmax::detail::MaxBlockSize<Backend>::value;
82+
int block_size = std::min(GetOptimalBlockSize(),
83+
kMaxBlockSize);
3684

3785
DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
38-
AllCudaBlockSizes>(
86+
causal_softmax::detail::SupportedBlockSizesFor<Backend>>(
3987
// TODO: Output dtype should use the one passed in during construction.
4088
{static_cast<int64_t>(out.dtype()), block_size},
4189
[&](auto list_tag) {

src/moore/causal_softmax/kernel.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_
2+
#define INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_
3+
4+
// clang-format off
5+
#include <musa_runtime.h>
6+
// clang-format on
7+
8+
// clang-format off
9+
#include "moore/device_.h"
10+
// clang-format on
11+
12+
#include "cuda/causal_softmax/kernel.h"
13+
14+
namespace infini::ops {
15+
16+
namespace causal_softmax {
17+
18+
struct MooreBackend {
19+
using stream_t = musaStream_t;
20+
21+
static constexpr int max_block_size = 256;
22+
};
23+
24+
} // namespace causal_softmax
25+
26+
template <>
27+
class Operator<CausalSoftmax, Device::Type::kMoore>
28+
: public CudaCausalSoftmax<causal_softmax::MooreBackend> {
29+
public:
30+
using CudaCausalSoftmax<causal_softmax::MooreBackend>::CudaCausalSoftmax;
31+
};
32+
33+
} // namespace infini::ops
34+
35+
#endif

0 commit comments

Comments
 (0)