-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathkernel_commons.h
More file actions
75 lines (58 loc) · 1.86 KB
/
kernel_commons.h
File metadata and controls
75 lines (58 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_
#define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_
#include <type_traits>
#include "caster.h"
namespace infini::ops {
// Full set of candidate block sizes, in ascending order. The
// `SupportedCudaBlockSizes<2048>` specialization below reuses this list, and
// the ladder in `ComputeOptimalBlockSize` mirrors it, so all three must be
// kept in sync. NOTE(review): `List` is not declared in this header; it
// presumably comes from "caster.h" (the only project include) — confirm.
using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>;
// Compile-time upper bound on the block size usable with a given backend.
//
// A backend opts in by exposing a static member `max_block_size`;
// `BackendMaxBlockSize<Backend>::value` then equals that member converted to
// `int`. Backends that expose no such member fall back to 2048.
template <typename BackendT, typename Enable = void>
struct BackendMaxBlockSize : std::integral_constant<int, 2048> {};

// Selected through `std::void_t` only when `BackendT::max_block_size` is a
// well-formed expression, i.e. the backend declares its own limit.
template <typename BackendT>
struct BackendMaxBlockSize<BackendT,
                           std::void_t<decltype(BackendT::max_block_size)>>
    : std::integral_constant<int, BackendT::max_block_size> {};
// Maps a maximum block size to the `List` of candidate block sizes (from
// `AllCudaBlockSizes`) that do not exceed it.
//
// The five exact sizes 128, 256, 512, 1024 and 2048 are spelled out by the
// explicit specializations below. Any other value reaches this primary
// template and is rounded down to the nearest of those five (anything above
// 2048 yields the full list). This matches the rounding in
// `ComputeOptimalBlockSize`, so a backend limit that is not exactly one of
// the five (e.g. via `BackendMaxBlockSize`) still resolves instead of failing
// on an incomplete type. Values below 128 are rejected with a clear
// compile-time error, because no candidate block size fits.
template <int max_block_size>
struct SupportedCudaBlockSizes
    : SupportedCudaBlockSizes<((max_block_size >= 2048)   ? 2048
                               : (max_block_size >= 1024) ? 1024
                               : (max_block_size >= 512)  ? 512
                               : (max_block_size >= 256)  ? 256
                                                          : 128)> {
  // The base above is always one of the explicit specializations and never
  // this primary template again, so the inheritance cannot recurse.
  static_assert(max_block_size >= 128,
                "SupportedCudaBlockSizes: max_block_size must be at least 128, "
                "the smallest supported block size");
};

template <>
struct SupportedCudaBlockSizes<2048> {
  using type = AllCudaBlockSizes;
};

template <>
struct SupportedCudaBlockSizes<1024> {
  using type = List<128, 256, 512, 1024>;
};

template <>
struct SupportedCudaBlockSizes<512> {
  using type = List<128, 256, 512>;
};

template <>
struct SupportedCudaBlockSizes<256> {
  using type = List<128, 256>;
};

template <>
struct SupportedCudaBlockSizes<128> {
  using type = List<128>;
};

// Shorthand for `SupportedCudaBlockSizes<max_block_size>::type`.
template <int max_block_size>
using SupportedCudaBlockSizesType =
    typename SupportedCudaBlockSizes<max_block_size>::type;
// Converts a row-major linear index into an offset, in the same units as
// `strides`, for a tensor described by `ndim` extents (`shape`) and strides.
//
// The index is peeled apart from the last (fastest-varying) dimension
// toward the first. Each per-dimension coordinate is multiplied by that
// dimension's stride and summed.
//
// Notes:
// - Every `shape[i]` is assumed to be non-zero (it is used as a divisor).
// - `strides` is signed, but each product is formed in `size_t` (unsigned,
//   modular) arithmetic. A negative stride therefore contributes its
//   two's-complement wraparound, and the result must be reinterpreted as
//   `ptrdiff_t` to recover a negative offset.
// - `ndim == 0` yields 0.
__forceinline__ __device__ __host__ size_t
IndexToOffset(size_t flat_index, size_t ndim, const size_t* shape,
              const ptrdiff_t* strides) {
  size_t offset = 0;
  size_t remaining = flat_index;
  for (size_t dim = ndim; dim != 0; --dim) {
    const size_t extent = shape[dim - 1];
    // One division gives both the carry and (via subtraction) the coordinate.
    const size_t carry = remaining / extent;
    const size_t coord = remaining - carry * extent;
    offset += coord * strides[dim - 1];
    remaining = carry;
  }
  return offset;
}
// Picks a launch block size from `AllCudaBlockSizes` for a device or backend
// limit of `max_threads_per_block` threads.
//
// Returns the largest of 128, 256, 512, 1024 and 2048 that does not exceed
// `max_threads_per_block`. The one exception is the floor: 128 is the
// smallest candidate, so any limit below 128 (including zero or negative
// values) still returns 128, which then exceeds the limit. Callers that may
// see such small limits must check the result themselves.
//
// The ladder must stay in sync with `AllCudaBlockSizes` and the
// `SupportedCudaBlockSizes` specializations. It is `constexpr` so the result
// can also be used as a template argument (implicitly inline, as before).
constexpr int ComputeOptimalBlockSize(int max_threads_per_block) {
  if (max_threads_per_block >= 2048) return 2048;
  if (max_threads_per_block >= 1024) return 1024;
  if (max_threads_per_block >= 512) return 512;
  if (max_threads_per_block >= 256) return 256;
  // Floor: no smaller candidate exists, even if the limit is below 128.
  return 128;
}
} // namespace infini::ops
#endif