11#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33
4+ #include < algorithm>
45#include < cstdint>
6+ #include < type_traits>
57
68#include " base/causal_softmax.h"
79#include " cuda/causal_softmax/kernel.cuh"
1113
1214namespace infini ::ops {
1315
16+ namespace causal_softmax ::detail {
17+
18+ template <typename Backend, typename = void >
19+ struct MaxBlockSize : std::integral_constant<int , 2048 > {};
20+
21+ template <typename Backend>
22+ struct MaxBlockSize <Backend, std::void_t <decltype (Backend::max_block_size)>>
23+ : std::integral_constant<int , Backend::max_block_size> {};
24+
25+ template <int max_block_size>
26+ struct SupportedBlockSizes ;
27+
28+ template <>
29+ struct SupportedBlockSizes <2048 > {
30+ using type = AllCudaBlockSizes;
31+ };
32+
33+ template <>
34+ struct SupportedBlockSizes <1024 > {
35+ using type = List<128 , 256 , 512 , 1024 >;
36+ };
37+
38+ template <>
39+ struct SupportedBlockSizes <512 > {
40+ using type = List<128 , 256 , 512 >;
41+ };
42+
43+ template <>
44+ struct SupportedBlockSizes <256 > {
45+ using type = List<128 , 256 >;
46+ };
47+
48+ template <>
49+ struct SupportedBlockSizes <128 > {
50+ using type = List<128 >;
51+ };
52+
53+ template <typename Backend>
54+ using SupportedBlockSizesFor =
55+ typename SupportedBlockSizes<MaxBlockSize<Backend>::value>::type;
56+
57+ } // namespace causal_softmax::detail
58+
1459template <typename Backend>
1560class CudaCausalSoftmax : public CausalSoftmax {
1661 public:
@@ -32,10 +77,13 @@ class CudaCausalSoftmax : public CausalSoftmax {
3277 std::abort ();
3378 }
3479
35- int block_size = GetOptimalBlockSize ();
80+ constexpr int kMaxBlockSize =
81+ causal_softmax::detail::MaxBlockSize<Backend>::value;
82+ int block_size = std::min (GetOptimalBlockSize (),
83+ kMaxBlockSize );
3684
3785 DispatchFunc<ConcatType<List<DataType::kFloat32 >, ReducedFloatTypes>,
38- AllCudaBlockSizes >(
86+ causal_softmax::detail::SupportedBlockSizesFor<Backend> >(
3987 // TODO: Output dtype should use the one passed in during construction.
4088 {static_cast <int64_t >(out.dtype ()), block_size},
4189 [&](auto list_tag) {
0 commit comments