-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathkernel.h
More file actions
58 lines (45 loc) · 1.84 KB
/
kernel.h
File metadata and controls
58 lines (45 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#include "base/causal_softmax.h"
#include "cuda/causal_softmax/kernel.cuh"
#include "cuda/kernel_commons.h"
#include "data_type.h"
#include "dispatcher.h"
namespace infini::ops {
// Host-side launcher for the causal-softmax CUDA kernel. Dispatches on the
// output dtype and a runtime-selected block size, then launches one block per
// (row, batch) pair on the backend's stream.
//
// NOTE(review): members stream_, ndim_, batch_size_, seq_len_, total_seq_len_,
// input_strides_ and out_strides_ are presumably set up by the CausalSoftmax
// base constructor — confirm against base/causal_softmax.h.
template <typename Backend>
class CudaCausalSoftmax : public CausalSoftmax {
 public:
  using CausalSoftmax::CausalSoftmax;

  // Launches the kernel computing causal softmax of `input` into `out`.
  // Both tensors must share the same dtype (one of kFloat32 or the reduced
  // float types accepted by the dispatcher); on mismatch this logs and aborts.
  // The launch is asynchronous on the backend stream; no post-launch error
  // check is performed here (TODO: surface cudaGetLastError via the backend).
  void operator()(const Tensor input, Tensor out) const override {
    // Fail fast, with a diagnostic, before doing any other work.
    if (out.dtype() != input.dtype()) {
      std::fprintf(stderr,
                   "CudaCausalSoftmax: input/output dtype mismatch\n");
      std::abort();
    }
    auto cuda_stream =
        static_cast<typename Backend::stream_t>(stream_ ? stream_ : 0);
    // For 2-D tensors there is a single implicit batch, so the batch stride
    // is irrelevant and set to 0; for 3-D tensors it is the outermost stride.
    auto stride_input_batch = ndim_ == 3 ? input_strides_[0] : 0;
    auto stride_input_row = input_strides_[ndim_ - 2];
    auto stride_out_batch = ndim_ == 3 ? out_strides_[0] : 0;
    auto stride_out_row = out_strides_[ndim_ - 2];
    // Grid layout: x = row within the sequence, y = batch index.
    dim3 grid(static_cast<unsigned>(seq_len_),
              static_cast<unsigned>(batch_size_));
    int block_size = GetOptimalBlockSize();
    DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
                 AllCudaBlockSizes>(
        // TODO: Output dtype should use the one passed in during construction.
        {static_cast<int64_t>(out.dtype()), block_size},
        [&](auto list_tag) {
          using T = TypeMapType<ListGet<0>(list_tag)>;
          constexpr int kBlockSize = ListGet<1>(list_tag);
          // Third template argument is float — presumably the accumulation
          // type for reduced-precision T; confirm against kernel.cuh.
          CausalSoftmaxKernel<kBlockSize, T, float>
              <<<grid, kBlockSize, 0, cuda_stream>>>(
                  reinterpret_cast<T*>(out.data()),
                  reinterpret_cast<const T*>(input.data()), batch_size_,
                  seq_len_, total_seq_len_, stride_out_batch, stride_out_row,
                  stride_input_batch, stride_input_row);
        },
        "CudaCausalSoftmax::operator()");
  }
};
} // namespace infini::ops
#endif