11#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33
4+ #include < algorithm>
45#include < cstdint>
6+ #include < type_traits>
57
68#include " base/causal_softmax.h"
79#include " common/cuda/kernel_commons.h"
1113
1214namespace infini ::ops {
1315
16+ namespace causal_softmax ::detail {
17+
18+ template <typename Backend, typename = void >
19+ struct MaxBlockSize : std::integral_constant<int , CUDA_BLOCK_SIZE_2048> {};
20+
21+ template <typename Backend>
22+ struct MaxBlockSize <Backend, std::void_t <decltype (Backend::max_block_size)>>
23+ : std::integral_constant<int , Backend::max_block_size> {};
24+
25+ } // namespace causal_softmax::detail
26+
1427template <typename Backend>
1528class CudaCausalSoftmax : public CausalSoftmax {
1629 public:
@@ -32,7 +45,10 @@ class CudaCausalSoftmax : public CausalSoftmax {
3245 std::abort ();
3346 }
3447
35- int block_size = GetOptimalBlockSize ();
48+ constexpr int kMaxBlockSize =
49+ causal_softmax::detail::MaxBlockSize<Backend>::value;
50+ int block_size = std::min (GetOptimalBlockSize (),
51+ kMaxBlockSize );
3652
3753 DispatchFunc<DataType::kFloat32 , DataType::kFloat16 , DataType::kBFloat16 >(
3854 out.dtype (),
@@ -47,17 +63,31 @@ class CudaCausalSoftmax : public CausalSoftmax {
4763 total_seq_len_, stride_out_batch, stride_out_row, \
4864 stride_input_batch, stride_input_row);
4965
50- if (block_size == CUDA_BLOCK_SIZE_2048) {
51- LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_2048)
52- } else if (block_size == CUDA_BLOCK_SIZE_1024) {
53- LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_1024)
54- } else if (block_size == CUDA_BLOCK_SIZE_512) {
55- LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_512)
56- } else if (block_size == CUDA_BLOCK_SIZE_256) {
57- LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_256)
58- } else {
59- LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_128)
66+ if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_2048) {
67+ if (block_size == CUDA_BLOCK_SIZE_2048) {
68+ LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_2048)
69+ return ;
70+ }
71+ }
72+ if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_1024) {
73+ if (block_size == CUDA_BLOCK_SIZE_1024) {
74+ LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_1024)
75+ return ;
76+ }
77+ }
78+ if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_512) {
79+ if (block_size == CUDA_BLOCK_SIZE_512) {
80+ LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_512)
81+ return ;
82+ }
83+ }
84+ if constexpr (kMaxBlockSize >= CUDA_BLOCK_SIZE_256) {
85+ if (block_size == CUDA_BLOCK_SIZE_256) {
86+ LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_256)
87+ return ;
88+ }
6089 }
90+ LAUNCH_CAUSAL_SOFTMAX_KERNEL (CUDA_BLOCK_SIZE_128)
6191
6292#undef LAUNCH_CAUSAL_SOFTMAX_KERNEL
6393 },
0 commit comments