Skip to content

Commit fe102f7

Browse files
committed
feat(ops): implement CausalSoftmax operator with Hygon backend.
1 parent abde23a commit fe102f7

2 files changed

Lines changed: 53 additions & 1 deletion

File tree

src/cuda/causal_softmax/kernel.h

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33

4+
#include <algorithm>
45
#include <cstdint>
6+
#include <type_traits>
57

68
#include "base/causal_softmax.h"
79
#include "cuda/causal_softmax/kernel.cuh"
@@ -11,6 +13,17 @@
1113

1214
namespace infini::ops {
1315

16+
// Implementation details for the causal-softmax operator: a compile-time
// trait that reports the largest block size a backend allows.
namespace causal_softmax::detail {
17+
18+
// Primary template: chosen when the partial specialization below does not
// apply, i.e. when `decltype(Backend::max_block_size)` is ill-formed.
// In that case `MaxBlockSize<Backend>::value` is 2048.
// The second template parameter exists only to drive the `std::void_t`
// detection and is never supplied by callers.
template <typename Backend, typename = void>
19+
struct MaxBlockSize : std::integral_constant<int, 2048> {};
20+
21+
// Partial specialization: chosen when `Backend::max_block_size` is well-formed
// inside `decltype` (`std::void_t` then yields `void`, matching the default
// argument of the primary template). `value` is `Backend::max_block_size`,
// which must therefore be usable as an `int` constant expression.
// `CudaCausalSoftmax` uses the resulting value as the upper bound in
// `std::min(GetOptimalBlockSize(), ...)` when picking its block size.
template <typename Backend>
22+
struct MaxBlockSize<Backend, std::void_t<decltype(Backend::max_block_size)>>
23+
: std::integral_constant<int, Backend::max_block_size> {};
24+
25+
} // namespace causal_softmax::detail
26+
1427
template <typename Backend>
1528
class CudaCausalSoftmax : public CausalSoftmax {
1629
public:
@@ -32,7 +45,9 @@ class CudaCausalSoftmax : public CausalSoftmax {
3245
std::abort();
3346
}
3447

35-
int block_size = GetOptimalBlockSize();
48+
constexpr int kMaxBlockSize =
49+
causal_softmax::detail::MaxBlockSize<Backend>::value;
50+
int block_size = std::min(GetOptimalBlockSize(), kMaxBlockSize);
3651

3752
DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
3853
AllCudaBlockSizes>(

src/hygon/causal_softmax/kernel.h

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,37 @@
1+
#ifndef INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
2+
#define INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
3+
4+
#include <utility>
5+
6+
// clang-format off
7+
#include <cuda_runtime.h>
8+
// clang-format on
9+
10+
// clang-format off
11+
#include "hygon/device_.h"
12+
// clang-format on
13+
14+
#include "cuda/causal_softmax/kernel.h"
15+
16+
namespace infini::ops {
17+
18+
// Backend description for Hygon devices, passed as the `Backend` template
// argument of `CudaCausalSoftmax`.
namespace causal_softmax {
19+
20+
struct HygonBackend {
21+
// Stream handle type for this backend; an alias of `cudaStream_t` from
// <cuda_runtime.h>.
using stream_t = cudaStream_t;
22+
23+
// Upper bound on the block size used for this backend. The `MaxBlockSize`
// trait in cuda/causal_softmax/kernel.h detects this member by name, so the
// block size picked by `CudaCausalSoftmax` is at most 256 here.
// NOTE(review): the reason for choosing 256 is not stated in this file —
// confirm against the Hygon device limits.
static constexpr int max_block_size = 256;
24+
};
25+
26+
} // namespace causal_softmax
27+
28+
// Explicit specialization of `Operator` for the `CausalSoftmax` operation on
// `Device::Type::kHygon`. It adds no members of its own: all behavior comes
// from `CudaCausalSoftmax` instantiated with `causal_softmax::HygonBackend`.
template <>
29+
class Operator<CausalSoftmax, Device::Type::kHygon>
30+
: public CudaCausalSoftmax<causal_softmax::HygonBackend> {
31+
public:
32+
// Inherit the base class constructors so this specialization can be
// constructed with the same arguments as `CudaCausalSoftmax`.
using CudaCausalSoftmax<causal_softmax::HygonBackend>::CudaCausalSoftmax;
33+
};
34+
35+
} // namespace infini::ops
36+
37+
#endif

0 commit comments

Comments
 (0)