diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 3dce77d..cdf6a4f 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ #define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_ +#include #include #include "base/causal_softmax.h" @@ -32,10 +33,12 @@ class CudaCausalSoftmax : public CausalSoftmax { std::abort(); } - int block_size = GetOptimalBlockSize(); + constexpr int kMaxBlockSize = BackendMaxBlockSize::value; + int block_size = std::min(GetOptimalBlockSize(), kMaxBlockSize); DispatchFunc, ReducedFloatTypes>, - AllCudaBlockSizes>( + SupportedCudaBlockSizesType< + BackendMaxBlockSize::value>>( // TODO: Output dtype should use the one passed in during construction. {static_cast(out.dtype()), block_size}, [&](auto list_tag) { diff --git a/src/cuda/kernel_commons.h b/src/cuda/kernel_commons.h index 6c987c7..8460d55 100644 --- a/src/cuda/kernel_commons.h +++ b/src/cuda/kernel_commons.h @@ -27,6 +27,7 @@ using cuda_bfloat162 = __mt_bfloat162; #include #include +#include #include #include "caster.h" @@ -35,6 +36,45 @@ namespace infini::ops { using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>; +template +struct BackendMaxBlockSize : std::integral_constant {}; + +template +struct BackendMaxBlockSize> + : std::integral_constant {}; + +template +struct SupportedCudaBlockSizes; + +template <> +struct SupportedCudaBlockSizes<2048> { + using type = AllCudaBlockSizes; +}; + +template <> +struct SupportedCudaBlockSizes<1024> { + using type = List<128, 256, 512, 1024>; +}; + +template <> +struct SupportedCudaBlockSizes<512> { + using type = List<128, 256, 512>; +}; + +template <> +struct SupportedCudaBlockSizes<256> { + using type = List<128, 256>; +}; + +template <> +struct SupportedCudaBlockSizes<128> { + using type = List<128>; +}; + +template +using SupportedCudaBlockSizesType = + typename SupportedCudaBlockSizes::type; + #if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) // Cache `cudaDeviceProp` per device, initialized once at first access. class DevicePropertyCache { diff --git a/src/moore/causal_softmax/kernel.h b/src/moore/causal_softmax/kernel.h new file mode 100644 index 0000000..71699fb --- /dev/null +++ b/src/moore/causal_softmax/kernel.h @@ -0,0 +1,35 @@ +#ifndef INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_ + +// clang-format off +#include +// clang-format on + +// clang-format off +#include "moore/device_.h" +// clang-format on + +#include "cuda/causal_softmax/kernel.h" + +namespace infini::ops { + +namespace causal_softmax { + +struct MooreBackend { + using stream_t = musaStream_t; + + static constexpr int max_block_size = 1024; +}; + +} // namespace causal_softmax + +template <> +class Operator + : public CudaCausalSoftmax { + public: + using CudaCausalSoftmax::CudaCausalSoftmax; +}; + +} // namespace infini::ops + +#endif