Skip to content

Commit 8683622

Browse files
committed
feat(ops): implement CausalSoftmax operator with Moore backend
- add the Moore operator specialization - limit block size in the shared CUDA-style kernel for Moore
1 parent f17e37c commit 8683622

3 files changed

Lines changed: 80 additions & 2 deletions

File tree

src/cuda/causal_softmax/kernel.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33

4+
#include <algorithm>
45
#include <cstdint>
56

67
#include "base/causal_softmax.h"
@@ -32,10 +33,12 @@ class CudaCausalSoftmax : public CausalSoftmax {
3233
std::abort();
3334
}
3435

35-
int block_size = GetOptimalBlockSize();
36+
constexpr int kMaxBlockSize = BackendMaxBlockSize<Backend>::value;
37+
int block_size = std::min(GetOptimalBlockSize(), kMaxBlockSize);
3638

3739
DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
38-
AllCudaBlockSizes>(
40+
SupportedCudaBlockSizesType<
41+
BackendMaxBlockSize<Backend>::value>>(
3942
// TODO: Output dtype should use the one passed in during construction.
4043
{static_cast<int64_t>(out.dtype()), block_size},
4144
[&](auto list_tag) {

src/cuda/kernel_commons.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ using cuda_bfloat162 = __mt_bfloat162;
2727

2828
#include <cstdlib>
2929
#include <iostream>
30+
#include <type_traits>
3031
#include <vector>
3132

3233
#include "caster.h"
@@ -35,6 +36,45 @@ namespace infini::ops {
3536

3637
using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>;
3738

39+
template <typename Backend, typename = void>
40+
struct BackendMaxBlockSize : std::integral_constant<int, 2048> {};
41+
42+
template <typename Backend>
43+
struct BackendMaxBlockSize<Backend, std::void_t<decltype(Backend::max_block_size)>>
44+
: std::integral_constant<int, Backend::max_block_size> {};
45+
46+
template <int max_block_size>
47+
struct SupportedCudaBlockSizes;
48+
49+
template <>
50+
struct SupportedCudaBlockSizes<2048> {
51+
using type = AllCudaBlockSizes;
52+
};
53+
54+
template <>
55+
struct SupportedCudaBlockSizes<1024> {
56+
using type = List<128, 256, 512, 1024>;
57+
};
58+
59+
template <>
60+
struct SupportedCudaBlockSizes<512> {
61+
using type = List<128, 256, 512>;
62+
};
63+
64+
template <>
65+
struct SupportedCudaBlockSizes<256> {
66+
using type = List<128, 256>;
67+
};
68+
69+
template <>
70+
struct SupportedCudaBlockSizes<128> {
71+
using type = List<128>;
72+
};
73+
74+
template <int max_block_size>
75+
using SupportedCudaBlockSizesType =
76+
typename SupportedCudaBlockSizes<max_block_size>::type;
77+
3878
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR)
3979
// Cache `cudaDeviceProp` per device, initialized once at first access.
4080
class DevicePropertyCache {

src/moore/causal_softmax/kernel.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_
2+
#define INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_
3+
4+
// clang-format off
5+
#include <musa_runtime.h>
6+
// clang-format on
7+
8+
// clang-format off
9+
#include "moore/device_.h"
10+
// clang-format on
11+
12+
#include "cuda/causal_softmax/kernel.h"
13+
14+
namespace infini::ops {
15+
16+
namespace causal_softmax {
17+
18+
struct MooreBackend {
19+
using stream_t = musaStream_t;
20+
21+
static constexpr int max_block_size = 1024;
22+
};
23+
24+
} // namespace causal_softmax
25+
26+
template <>
27+
class Operator<CausalSoftmax, Device::Type::kMoore>
28+
: public CudaCausalSoftmax<causal_softmax::MooreBackend> {
29+
public:
30+
using CudaCausalSoftmax<causal_softmax::MooreBackend>::CudaCausalSoftmax;
31+
};
32+
33+
} // namespace infini::ops
34+
35+
#endif

0 commit comments

Comments
 (0)