Skip to content

Commit a9e11e9

Browse files
committed
feat(ops): implement CausalSoftmax operator with Moore backend
- Add the Moore operator specialization.
- Limit the block size in the shared CUDA-style kernel for Moore.
1 parent 61fcdf7 commit a9e11e9

2 files changed

Lines changed: 72 additions & 11 deletions

File tree

src/cuda/causal_softmax/kernel.h

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33

4+
#include <algorithm>
45
#include <cstdint>
6+
#include <type_traits>
57

68
#include "base/causal_softmax.h"
79
#include "common/cuda/kernel_commons.h"
@@ -11,6 +13,17 @@
1113

1214
namespace infini::ops {
1315

16+
namespace causal_softmax::detail {
17+
18+
// Compile-time upper bound on the thread-block size used for `Backend`.
//
// Primary template: selected when `Backend` does not declare a
// `max_block_size` member. The bound is then `CUDA_BLOCK_SIZE_2048`.
template <typename Backend, typename = void>
19+
struct MaxBlockSize : std::integral_constant<int, CUDA_BLOCK_SIZE_2048> {};
20+
21+
// Partial specialization selected through `std::void_t` when the expression
// `Backend::max_block_size` is well-formed. The bound is that member's value,
// which must be usable as an `int` constant expression.
template <typename Backend>
22+
struct MaxBlockSize<Backend, std::void_t<decltype(Backend::max_block_size)>>
23+
: std::integral_constant<int, Backend::max_block_size> {};
24+
25+
} // namespace causal_softmax::detail
26+
1427
template <typename Backend>
1528
class CudaCausalSoftmax : public CausalSoftmax {
1629
public:
@@ -32,7 +45,10 @@ class CudaCausalSoftmax : public CausalSoftmax {
3245
std::abort();
3346
}
3447

35-
int block_size = GetOptimalBlockSize();
48+
constexpr int kMaxBlockSize =
49+
causal_softmax::detail::MaxBlockSize<Backend>::value;
50+
int block_size = std::min(GetOptimalBlockSize(),
51+
kMaxBlockSize);
3652

3753
DispatchFunc<DataType::kFloat32, DataType::kFloat16, DataType::kBFloat16>(
3854
out.dtype(),
@@ -47,17 +63,31 @@ class CudaCausalSoftmax : public CausalSoftmax {
4763
total_seq_len_, stride_out_batch, stride_out_row, \
4864
stride_input_batch, stride_input_row);
4965

50-
if (block_size == CUDA_BLOCK_SIZE_2048) {
51-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_2048)
52-
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
53-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_1024)
54-
} else if (block_size == CUDA_BLOCK_SIZE_512) {
55-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_512)
56-
} else if (block_size == CUDA_BLOCK_SIZE_256) {
57-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_256)
58-
} else {
59-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_128)
66+
if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_2048) {
67+
if (block_size == CUDA_BLOCK_SIZE_2048) {
68+
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_2048)
69+
return;
70+
}
71+
}
72+
if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_1024) {
73+
if (block_size == CUDA_BLOCK_SIZE_1024) {
74+
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_1024)
75+
return;
76+
}
77+
}
78+
if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_512) {
79+
if (block_size == CUDA_BLOCK_SIZE_512) {
80+
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_512)
81+
return;
82+
}
83+
}
84+
if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_256) {
85+
if (block_size == CUDA_BLOCK_SIZE_256) {
86+
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_256)
87+
return;
88+
}
6089
}
90+
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_128)
6191

6292
#undef LAUNCH_CAUSAL_SOFTMAX_KERNEL
6393
},

src/moore/causal_softmax/kernel.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_
2+
#define INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_
3+
4+
// clang-format off
5+
#include <musa_runtime.h>
6+
// clang-format on
7+
8+
#include "cuda/causal_softmax/kernel.h"
9+
10+
namespace infini::ops {
11+
12+
namespace causal_softmax {
13+
14+
// Backend description for the Moore device, passed as the `Backend` template
// argument of `CudaCausalSoftmax`.
struct MooreBackend {
15+
// Stream handle type; `musaStream_t` presumably comes from
// <musa_runtime.h>, which this header includes.
using stream_t = musaStream_t;
16+
17+
// Read by `causal_softmax::detail::MaxBlockSize` (see
// "cuda/causal_softmax/kernel.h") to cap the launch block size.
// NOTE(review): the reason for choosing `CUDA_BLOCK_SIZE_256` is not stated
// in this file; confirm against the Moore device limits.
static constexpr int max_block_size = CUDA_BLOCK_SIZE_256;
18+
};
19+
20+
} // namespace causal_softmax
21+
22+
// `Operator` specialization for `CausalSoftmax` on `Device::Type::kMoore`.
// It adds no members of its own and derives from the shared
// `CudaCausalSoftmax` template instantiated with `MooreBackend`.
template <>
23+
class Operator<CausalSoftmax, Device::Type::kMoore>
24+
: public CudaCausalSoftmax<causal_softmax::MooreBackend> {
25+
public:
26+
// Inherit the base class constructors unchanged.
using CudaCausalSoftmax<causal_softmax::MooreBackend>::CudaCausalSoftmax;
27+
};
28+
29+
} // namespace infini::ops
30+
31+
#endif

0 commit comments

Comments
 (0)