-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcaster_.h
More file actions
113 lines (96 loc) · 3.57 KB
/
caster_.h
File metadata and controls
113 lines (96 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#ifndef INFINI_OPS_COMMON_CUDA_CASTER_H_
#define INFINI_OPS_COMMON_CUDA_CASTER_H_
#ifdef WITH_NVIDIA
#include <cuda_runtime.h>
#elif defined(WITH_ILUVATAR)
#include <cuda_runtime.h>
#elif defined(WITH_HYGON)
#include <cuda_runtime.h>
#elif defined(WITH_METAX)
#include <mcr/mc_runtime.h>
#elif defined(WITH_MOORE)
#include <musa_runtime.h>
#endif

#include <type_traits>
#include <utility>

#include "caster.h"
namespace infini::ops {
// `Caster` specialization for `kDeviceType` (declared outside this file).
//
// `Cast<Dst>(x)` resolves in this order:
//   1. Source and destination are the same type once cv/ref qualifiers are
//      stripped: `x` is forwarded unchanged.
//   2. A dedicated conversion intrinsic is registered for the (Src, Dst) pair
//      via `DEFINE_DIRECT_CAST`: that intrinsic is called.
//   3. Both types are built-in arithmetic types: a plain `static_cast`.
//   4. Anything else (e.g. FP16/BF16 involved without a direct intrinsic):
//      the value is widened to `float` and then narrowed to `Dst`.
template <>
struct Caster<kDeviceType> {
  // Converts `x` to `Dst`. `Dst` must be a non-reference type.
  template <typename Dst, typename Src>
  __host__ __device__ static Dst Cast(Src&& x) {
    static_assert(!std::is_reference_v<Dst>,
                  "`Cast` cannot return reference types");

    using PureSrc = std::remove_cv_t<std::remove_reference_t<Src>>;
    using PureDst = std::remove_cv_t<std::remove_reference_t<Dst>>;

    if constexpr (std::is_same_v<PureSrc, PureDst>) {
      return std::forward<Src>(x);
    } else {
      // Passing `PriorityHigh{}` makes any viable direct-intrinsic overload
      // win; otherwise the derived-to-base conversion selects the fallback.
      return HardwareCast<PureDst>(std::forward<Src>(x), PriorityHigh{});
    }
  }

 private:
  // `T` with references and top-level cv-qualifiers removed.
  template <typename T>
  using PureType = std::remove_cv_t<std::remove_reference_t<T>>;

  // Widens `x` to `float`: BF16/FP16 use their conversion intrinsics,
  // every other type uses `static_cast<float>`.
  template <typename T>
  __host__ __device__ static constexpr float ToFloatHelper(T&& x) {
    using PureSrc = PureType<T>;
    if constexpr (IsBFloat16<PureSrc>) {
      return __bfloat162float(x);
    } else if constexpr (IsFP16<PureSrc>) {
      return __half2float(x);
    } else {
      return static_cast<float>(std::forward<T>(x));
    }
  }

  // Narrows a `float` to `Dst`: BF16/FP16 use their conversion intrinsics,
  // every other type uses `static_cast<Dst>`.
  template <typename Dst>
  __host__ __device__ static constexpr Dst FromFloatHelper(float f) {
    using PureDst = PureType<Dst>;
    if constexpr (IsBFloat16<PureDst>) {
      return __float2bfloat16(f);
    } else if constexpr (IsFP16<PureDst>) {
      return __float2half(f);
    } else {
      return static_cast<Dst>(f);
    }
  }

  // Priority tags for overload resolution. A `PriorityHigh` argument binds
  // exactly to a `PriorityHigh` parameter and only through a derived-to-base
  // conversion to a `PriorityLow` one. High-priority overloads therefore win
  // whenever both are viable.
  struct PriorityLow {};
  struct PriorityHigh : PriorityLow {};

  // Fallback: lowest priority; always viable when no direct intrinsic matches.
  template <typename Dst, typename Src>
  __host__ __device__ static constexpr Dst HardwareCast(Src&& x, PriorityLow) {
    if constexpr (std::is_arithmetic_v<PureType<Src>> &&
                  std::is_arithmetic_v<PureType<Dst>>) {
      // Built-in to built-in: convert directly. Routing through `float` would
      // drop precision for values that need more than float's 24-bit
      // significand (e.g. large `int64_t`/`double`). It would also turn
      // integer narrowing such as `int` -> `int8_t` into an out-of-range
      // float-to-integer conversion, which is undefined behavior.
      return static_cast<Dst>(std::forward<Src>(x));
    } else {
      return FromFloatHelper<Dst>(ToFloatHelper(std::forward<Src>(x)));
    }
  }

  // Defines a high-priority `HardwareCast` overload that calls `INTRINSIC`
  // directly. The overload takes part only when both of these hold:
  //   - the condition (the variadic arguments, which may refer to `Dst` and
  //     `Src`) is true;
  //   - `INTRINSIC` is callable with a `Src`.
  // Usage: `DEFINE_DIRECT_CAST(INTRINSIC, CONDITION)`.
#define DEFINE_DIRECT_CAST(INTRINSIC, ...)                            \
  template <typename Dst, typename Src>                               \
  __host__ __device__ static auto HardwareCast(Src x, PriorityHigh)   \
      -> std::enable_if_t<(__VA_ARGS__),                              \
                          decltype(INTRINSIC(std::declval<Src>()))> { \
    return INTRINSIC(x);                                              \
  }

  DEFINE_DIRECT_CAST(
      __bfloat162int_rn,
      std::is_same_v<PureType<Dst>, int> && IsBFloat16<PureType<Src>>)
  DEFINE_DIRECT_CAST(
      __bfloat162short_rn,
      std::is_same_v<PureType<Dst>, short> && IsBFloat16<PureType<Src>>)
  DEFINE_DIRECT_CAST(
      __int2bfloat16_rn,
      IsBFloat16<PureType<Dst>> && std::is_same_v<PureType<Src>, int>)
  DEFINE_DIRECT_CAST(
      __int2half_rn,
      IsFP16<PureType<Dst>> && std::is_same_v<PureType<Src>, int>)
  DEFINE_DIRECT_CAST(
      __double2bfloat16,
      IsBFloat16<PureType<Dst>> && std::is_same_v<PureType<Src>, double>)
  DEFINE_DIRECT_CAST(
      __double2half,
      IsFP16<PureType<Dst>> && std::is_same_v<PureType<Src>, double>)
  // BF16 -> FP16 through a functional cast to `__half`.
  DEFINE_DIRECT_CAST(__half, IsFP16<PureType<Dst>> && IsBFloat16<PureType<Src>>)

#undef DEFINE_DIRECT_CAST
};
// Converts the forwarded argument(s) to `Dst`. All of the work is delegated to
// the `Caster` specialization selected by `kDeviceType`.
template <typename Dst, typename... Args>
__host__ __device__ __forceinline__ auto Cast(Args&&... args) {
  using DeviceCaster = Caster<kDeviceType>;
  return DeviceCaster::template Cast<Dst>(std::forward<Args>(args)...);
}
} // namespace infini::ops
#endif