File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33
4+ #include < algorithm>
45#include < cstdint>
6+ #include < type_traits>
57
68#include " base/causal_softmax.h"
79#include " cuda/causal_softmax/kernel.cuh"
1113
1214namespace infini ::ops {
1315
16+ namespace causal_softmax ::detail {
17+
18+ template <typename Backend, typename = void >
19+ struct MaxBlockSize : std::integral_constant<int , 2048 > {};
20+
21+ template <typename Backend>
22+ struct MaxBlockSize <Backend, std::void_t <decltype (Backend::max_block_size)>>
23+ : std::integral_constant<int , Backend::max_block_size> {};
24+
25+ } // namespace causal_softmax::detail
26+
1427template <typename Backend>
1528class CudaCausalSoftmax : public CausalSoftmax {
1629 public:
@@ -32,7 +45,9 @@ class CudaCausalSoftmax : public CausalSoftmax {
3245 std::abort ();
3346 }
3447
35- int block_size = GetOptimalBlockSize ();
48+ constexpr int kMaxBlockSize =
49+ causal_softmax::detail::MaxBlockSize<Backend>::value;
50+ int block_size = std::min (GetOptimalBlockSize (), kMaxBlockSize );
3651
3752 DispatchFunc<ConcatType<List<DataType::kFloat32 >, ReducedFloatTypes>,
3853 AllCudaBlockSizes>(
Original file line number Diff line number Diff line change 1+ #ifndef INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
2+ #define INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
3+
4+ #include < utility>
5+
6+ // clang-format off
7+ #include < cuda_runtime.h>
8+ // clang-format on
9+
10+ // clang-format off
11+ #include " hygon/device_.h"
12+ // clang-format on
13+
14+ #include " cuda/causal_softmax/kernel.h"
15+
16+ namespace infini ::ops {
17+
18+ namespace causal_softmax {
19+
20+ struct HygonBackend {
21+ using stream_t = cudaStream_t;
22+
23+ static constexpr int max_block_size = 256 ;
24+ };
25+
26+ } // namespace causal_softmax
27+
28+ template <>
29+ class Operator <CausalSoftmax, Device::Type::kHygon >
30+ : public CudaCausalSoftmax<causal_softmax::HygonBackend> {
31+ public:
32+ using CudaCausalSoftmax<causal_softmax::HygonBackend>::CudaCausalSoftmax;
33+ };
34+
35+ } // namespace infini::ops
36+
37+ #endif
You can’t perform that action at this time.
0 commit comments