Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/cuda/causal_softmax/kernel.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_

#include <algorithm>
#include <cstdint>

#include "base/causal_softmax.h"
Expand Down Expand Up @@ -32,10 +33,12 @@ class CudaCausalSoftmax : public CausalSoftmax {
std::abort();
}

int block_size = GetOptimalBlockSize();
constexpr int kMaxBlockSize = BackendMaxBlockSize<Backend>::value;
int block_size = std::min(GetOptimalBlockSize(), kMaxBlockSize);

DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
AllCudaBlockSizes>(
SupportedCudaBlockSizesType<
BackendMaxBlockSize<Backend>::value>>(
// TODO: Output dtype should use the one passed in during construction.
{static_cast<int64_t>(out.dtype()), block_size},
[&](auto list_tag) {
Expand Down
40 changes: 40 additions & 0 deletions src/cuda/kernel_commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ using cuda_bfloat162 = __mt_bfloat162;

#include <cstdlib>
#include <iostream>
#include <type_traits>
#include <vector>

#include "caster.h"
Expand All @@ -35,6 +36,45 @@ namespace infini::ops {

using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>;

template <typename Backend, typename = void>
struct BackendMaxBlockSize : std::integral_constant<int, 2048> {};

template <typename Backend>
struct BackendMaxBlockSize<Backend, std::void_t<decltype(Backend::max_block_size)>>
: std::integral_constant<int, Backend::max_block_size> {};

template <int max_block_size>
struct SupportedCudaBlockSizes;

template <>
struct SupportedCudaBlockSizes<2048> {
using type = AllCudaBlockSizes;
};

template <>
struct SupportedCudaBlockSizes<1024> {
using type = List<128, 256, 512, 1024>;
};

template <>
struct SupportedCudaBlockSizes<512> {
using type = List<128, 256, 512>;
};

template <>
struct SupportedCudaBlockSizes<256> {
using type = List<128, 256>;
};

template <>
struct SupportedCudaBlockSizes<128> {
using type = List<128>;
};

template <int max_block_size>
using SupportedCudaBlockSizesType =
typename SupportedCudaBlockSizes<max_block_size>::type;

#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR)
// Cache `cudaDeviceProp` per device, initialized once at first access.
class DevicePropertyCache {
Expand Down
35 changes: 35 additions & 0 deletions src/moore/causal_softmax/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_
#define INFINI_OPS_MOORE_CAUSAL_SOFTMAX_KERNEL_H_

// clang-format off
#include <musa_runtime.h>
// clang-format on

// clang-format off
#include "moore/device_.h"
// clang-format on

#include "cuda/causal_softmax/kernel.h"

namespace infini::ops {

namespace causal_softmax {

struct MooreBackend {
using stream_t = musaStream_t;

static constexpr int max_block_size = 1024;
};

} // namespace causal_softmax

template <>
class Operator<CausalSoftmax, Device::Type::kMoore>
: public CudaCausalSoftmax<causal_softmax::MooreBackend> {
public:
using CudaCausalSoftmax<causal_softmax::MooreBackend>::CudaCausalSoftmax;
};

} // namespace infini::ops

#endif