|
6 | 6 | namespace infini_train::core { |
7 | 7 |
|
8 | 8 | /** |
9 | | - * Backend type mapping: DataType -> backend dispatch type |
| 9 | + * Backend type mapping: DataType -> backend-native dispatch type |
10 | 10 | * |
11 | 11 | * NativeScalar — maps framework low-precision scalar types (FP16/BF16) to |
12 | 12 | * backend-native scalar types (__half / __nv_bfloat16). |
13 | | - * Each backend specializes only the types it needs. |
| 13 | + * Primary template intentionally undefined. |
| 14 | + * Each backend specializes only the types it supports. |
14 | 15 | * |
15 | 16 | * BackendTypeMap — maps DataType to the C++ type used by kernels/dispatch. |
16 | | - * Falls back to the framework TypeMap for all types |
17 | | - * except FP16/BF16, which are routed through NativeScalar. |
| 17 | + * Primary template intentionally undefined — there is NO |
| 18 | + * default fallback to the framework TypeMap<DType>. |
18 | 19 | * |
19 | | - * Value-level conversion between framework scalars and native scalars is out |
20 | | - * of scope here; kernels use common::cuda::Cast<T> directly. |
| 20 | + * Backends must register dtypes explicitly: |
| 21 | + * - Standard types (int, float, double, ...): |
| 22 | + * call INFINI_REGISTER_STANDARD_BACKEND_TYPES(Dev) |
| 23 | + * once at file scope in the backend's dispatch header. |
| 24 | + * - Low-precision types (FP16, BF16): |
| 25 | + * specialize NativeScalar<Dev, infini_train::FP16/BF16>. |
| 26 | + * The generic partial specializations below then resolve |
| 27 | + * automatically via a SFINAE-safe helper. |
| 28 | + * |
| 29 | + * If a backend does not register a dtype, HasMappedType_v returns false and |
| 30 | + * DispatchByTypeMap fires a clear static_assert at compile time. |
21 | 31 | */ |
22 | 32 |
|
23 | 33 | // ----------------------------------------------------------------------------- |
24 | | -// NativeScalar: framework scalar -> backend native scalar type mapping |
| 34 | +// NativeScalar: framework scalar -> backend native scalar |
25 | 35 | // Primary template intentionally undefined. |
26 | | -// Each backend specializes the scalar types it supports. |
27 | 36 | // ----------------------------------------------------------------------------- |
28 | 37 | template <Device::DeviceType Dev, typename Scalar> struct NativeScalar; |
29 | 38 |
|
30 | 39 | template <Device::DeviceType Dev, typename Scalar> using NativeScalar_t = typename NativeScalar<Dev, Scalar>::type; |
31 | 40 |
|
32 | 41 | // ----------------------------------------------------------------------------- |
33 | 42 | // BackendTypeMap: DataType -> backend dispatch type |
34 | | -// Primary template falls back to the framework TypeMap. |
35 | | -// FP16/BF16 are overridden to resolve through NativeScalar. |
| 43 | +// Primary template intentionally undefined — no TypeMap<DType> fallback. |
36 | 44 | // ----------------------------------------------------------------------------- |
37 | | -template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap : infini_train::TypeMap<DType> {}; |
| 45 | +template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap; |
38 | 46 |
|
39 | | -template <Device::DeviceType Dev> struct BackendTypeMap<Dev, DataType::kFLOAT16> { |
40 | | - using type = NativeScalar_t<Dev, infini_train::FP16>; |
41 | | -}; |
| 47 | +// ----------------------------------------------------------------------------- |
| 48 | +// SFINAE-safe helper for low-precision type routing. |
| 49 | +// When NativeScalar<Dev, Scalar> is undefined, this struct has no `type` |
| 50 | +// member, making HasMappedType_v<..., kFLOAT16/kBFLOAT16> return false and |
| 51 | +// triggering the static_assert in dispatch rather than an opaque hard error. |
| 52 | +// ----------------------------------------------------------------------------- |
| 53 | +namespace detail { |
| 54 | + |
| 55 | +template <Device::DeviceType Dev, typename Scalar, typename = void> |
| 56 | +struct BackendLowPrecisionTypeHelper {}; // no `type` member when NativeScalar absent |
42 | 57 |
|
43 | | -template <Device::DeviceType Dev> struct BackendTypeMap<Dev, DataType::kBFLOAT16> { |
44 | | - using type = NativeScalar_t<Dev, infini_train::BF16>; |
| 58 | +template <Device::DeviceType Dev, typename Scalar> |
| 59 | +struct BackendLowPrecisionTypeHelper<Dev, Scalar, std::void_t<typename NativeScalar<Dev, Scalar>::type>> { |
| 60 | + using type = typename NativeScalar<Dev, Scalar>::type; |
45 | 61 | }; |
46 | 62 |
|
| 63 | +} // namespace detail |
| 64 | + |
| 65 | +// Low-precision partial specializations: generic over Dev, resolved via NativeScalar. |
| 66 | +template <Device::DeviceType Dev> |
| 67 | +struct BackendTypeMap<Dev, DataType::kFLOAT16> : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::FP16> {}; |
| 68 | + |
| 69 | +template <Device::DeviceType Dev> |
| 70 | +struct BackendTypeMap<Dev, DataType::kBFLOAT16> : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::BF16> {}; |
| 71 | + |
47 | 72 | } // namespace infini_train::core |
| 73 | + |
| 74 | +// ----------------------------------------------------------------------------- |
| 75 | +// INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) |
| 76 | +// |
| 77 | +// Explicitly registers the 10 standard (non-low-precision) dtypes for a backend |
| 78 | +// device. Invoke once at file scope (outside any namespace) in the backend's |
| 79 | +// dispatch header, e.g.: |
| 80 | +// |
| 81 | +// INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kCUDA) |
| 82 | +// |
| 83 | +// FP16 and BF16 are NOT registered here — they are handled via NativeScalar. |
| 84 | +// ----------------------------------------------------------------------------- |
| 85 | +#define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \ |
| 86 | + namespace infini_train::core { \ |
| 87 | + template <> struct BackendTypeMap<DEV, DataType::kUINT8> { \ |
| 88 | + using type = uint8_t; \ |
| 89 | + }; \ |
| 90 | + template <> struct BackendTypeMap<DEV, DataType::kINT8> { \ |
| 91 | + using type = int8_t; \ |
| 92 | + }; \ |
| 93 | + template <> struct BackendTypeMap<DEV, DataType::kUINT16> { \ |
| 94 | + using type = uint16_t; \ |
| 95 | + }; \ |
| 96 | + template <> struct BackendTypeMap<DEV, DataType::kINT16> { \ |
| 97 | + using type = int16_t; \ |
| 98 | + }; \ |
| 99 | + template <> struct BackendTypeMap<DEV, DataType::kUINT32> { \ |
| 100 | + using type = uint32_t; \ |
| 101 | + }; \ |
| 102 | + template <> struct BackendTypeMap<DEV, DataType::kINT32> { \ |
| 103 | + using type = int32_t; \ |
| 104 | + }; \ |
| 105 | + template <> struct BackendTypeMap<DEV, DataType::kUINT64> { \ |
| 106 | + using type = uint64_t; \ |
| 107 | + }; \ |
| 108 | + template <> struct BackendTypeMap<DEV, DataType::kINT64> { \ |
| 109 | + using type = int64_t; \ |
| 110 | + }; \ |
| 111 | + template <> struct BackendTypeMap<DEV, DataType::kFLOAT32> { \ |
| 112 | + using type = float; \ |
| 113 | + }; \ |
| 114 | + template <> struct BackendTypeMap<DEV, DataType::kFLOAT64> { \ |
| 115 | + using type = double; \ |
| 116 | + }; \ |
| 117 | + } /* namespace infini_train::core */ |
0 commit comments