Skip to content

Commit e791b1d

Browse files
committed
refactor: introduce backend-explicit dtype dispatch and DataType-level promotion
Replace the implicit TypeMap fallback in BackendTypeMap with explicit per-backend dtype registration (INFINI_REGISTER_STANDARD_BACKEND_TYPES), ensuring FP16/BF16 are only dispatched through backend-specific paths (DispatchCpuFunc/DispatchCudaFunc). Migrate CUDA kernel promotion from concrete-type WidestType_t to pure DataType enum-level PromoteDataTypes(), eliminating the need for backend scalar types at promotion time. Replace runtime kDataTypeToSize map with constexpr DTypeSize().
1 parent bec0f8c commit e791b1d

16 files changed

Lines changed: 363 additions & 94 deletions

File tree

infini_train/include/core/backend_type_map.h

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,42 +6,112 @@
66
namespace infini_train::core {
77

88
/**
9-
* Backend type mapping: DataType -> backend dispatch type
9+
* Backend type mapping: DataType -> backend-native dispatch type
1010
*
1111
* NativeScalar — maps framework low-precision scalar types (FP16/BF16) to
1212
* backend-native scalar types (__half / __nv_bfloat16).
13-
* Each backend specializes only the types it needs.
13+
* Primary template intentionally undefined.
14+
* Each backend specializes only the types it supports.
1415
*
1516
* BackendTypeMap — maps DataType to the C++ type used by kernels/dispatch.
16-
* Falls back to the framework TypeMap for all types
17-
* except FP16/BF16, which are routed through NativeScalar.
17+
* Primary template intentionally undefined — there is NO
18+
* default fallback to the framework TypeMap<DType>.
1819
*
19-
* Value-level conversion between framework scalars and native scalars is out
20-
* of scope here; kernels use common::cuda::Cast<T> directly.
20+
* Backends must register dtypes explicitly:
21+
* - Standard types (int, float, double, ...):
22+
* call INFINI_REGISTER_STANDARD_BACKEND_TYPES(Dev)
23+
* once at file scope in the backend's dispatch header.
24+
* - Low-precision types (FP16, BF16):
25+
* specialize NativeScalar<Dev, infini_train::FP16/BF16>.
26+
* The generic partial specializations below then resolve
27+
* automatically via a SFINAE-safe helper.
28+
*
29+
* If a backend does not register a dtype, HasMappedType_v evaluates to false and
30+
* DispatchByTypeMap fires a clear static_assert at compile time.
2131
*/
2232

2333
// -----------------------------------------------------------------------------
24-
// NativeScalar: framework scalar -> backend native scalar type mapping
34+
// NativeScalar: framework scalar -> backend native scalar
2535
// Primary template intentionally undefined.
26-
// Each backend specializes the scalar types it supports.
2736
// -----------------------------------------------------------------------------
2837
template <Device::DeviceType Dev, typename Scalar> struct NativeScalar;
2938

3039
template <Device::DeviceType Dev, typename Scalar> using NativeScalar_t = typename NativeScalar<Dev, Scalar>::type;
3140

3241
// -----------------------------------------------------------------------------
3342
// BackendTypeMap: DataType -> backend dispatch type
34-
// Primary template falls back to the framework TypeMap.
35-
// FP16/BF16 are overridden to resolve through NativeScalar.
43+
// Primary template intentionally undefined — no TypeMap<DType> fallback.
3644
// -----------------------------------------------------------------------------
37-
template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap : infini_train::TypeMap<DType> {};
45+
template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap;
3846

39-
template <Device::DeviceType Dev> struct BackendTypeMap<Dev, DataType::kFLOAT16> {
40-
using type = NativeScalar_t<Dev, infini_train::FP16>;
41-
};
47+
// -----------------------------------------------------------------------------
48+
// SFINAE-safe helper for low-precision type routing.
49+
// When NativeScalar<Dev, Scalar> is undefined, this struct has no `type`
50+
// member, making HasMappedType_v<..., kFLOAT16/kBFLOAT16> return false and
51+
// triggering the static_assert in dispatch rather than an opaque hard error.
52+
// -----------------------------------------------------------------------------
53+
namespace detail {
54+
55+
template <Device::DeviceType Dev, typename Scalar, typename = void>
56+
struct BackendLowPrecisionTypeHelper {}; // no `type` member when NativeScalar absent
4257

43-
template <Device::DeviceType Dev> struct BackendTypeMap<Dev, DataType::kBFLOAT16> {
44-
using type = NativeScalar_t<Dev, infini_train::BF16>;
58+
template <Device::DeviceType Dev, typename Scalar>
59+
struct BackendLowPrecisionTypeHelper<Dev, Scalar, std::void_t<typename NativeScalar<Dev, Scalar>::type>> {
60+
using type = typename NativeScalar<Dev, Scalar>::type;
4561
};
4662

63+
} // namespace detail
64+
65+
// Low-precision partial specializations: generic over Dev, resolved via NativeScalar.
66+
template <Device::DeviceType Dev>
67+
struct BackendTypeMap<Dev, DataType::kFLOAT16> : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::FP16> {};
68+
69+
template <Device::DeviceType Dev>
70+
struct BackendTypeMap<Dev, DataType::kBFLOAT16> : detail::BackendLowPrecisionTypeHelper<Dev, infini_train::BF16> {};
71+
4772
} // namespace infini_train::core
73+
74+
// -----------------------------------------------------------------------------
75+
// INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)
76+
//
77+
// Explicitly registers the 10 standard (non-low-precision) dtypes for a backend
78+
// device. Invoke once at file scope (outside any namespace) in the backend's
79+
// dispatch header, e.g.:
80+
//
81+
// INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kCUDA)
82+
//
83+
// FP16 and BF16 are NOT registered here — they are handled via NativeScalar.
84+
// -----------------------------------------------------------------------------
85+
#define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \
86+
namespace infini_train::core { \
87+
template <> struct BackendTypeMap<DEV, DataType::kUINT8> { \
88+
using type = uint8_t; \
89+
}; \
90+
template <> struct BackendTypeMap<DEV, DataType::kINT8> { \
91+
using type = int8_t; \
92+
}; \
93+
template <> struct BackendTypeMap<DEV, DataType::kUINT16> { \
94+
using type = uint16_t; \
95+
}; \
96+
template <> struct BackendTypeMap<DEV, DataType::kINT16> { \
97+
using type = int16_t; \
98+
}; \
99+
template <> struct BackendTypeMap<DEV, DataType::kUINT32> { \
100+
using type = uint32_t; \
101+
}; \
102+
template <> struct BackendTypeMap<DEV, DataType::kINT32> { \
103+
using type = int32_t; \
104+
}; \
105+
template <> struct BackendTypeMap<DEV, DataType::kUINT64> { \
106+
using type = uint64_t; \
107+
}; \
108+
template <> struct BackendTypeMap<DEV, DataType::kINT64> { \
109+
using type = int64_t; \
110+
}; \
111+
template <> struct BackendTypeMap<DEV, DataType::kFLOAT32> { \
112+
using type = float; \
113+
}; \
114+
template <> struct BackendTypeMap<DEV, DataType::kFLOAT64> { \
115+
using type = double; \
116+
}; \
117+
} /* namespace infini_train::core */

infini_train/include/datatype.h

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,35 @@ enum class DataType : int8_t {
212212
kFLOAT64,
213213
};
214214

215-
inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
216-
{DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2},
217-
{DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8},
218-
{DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
219-
};
215+
constexpr size_t DTypeSize(DataType data_type) {
216+
switch (data_type) {
217+
case DataType::kUINT8:
218+
return 1;
219+
case DataType::kINT8:
220+
return 1;
221+
case DataType::kUINT16:
222+
return 2;
223+
case DataType::kINT16:
224+
return 2;
225+
case DataType::kUINT32:
226+
return 4;
227+
case DataType::kINT32:
228+
return 4;
229+
case DataType::kUINT64:
230+
return 8;
231+
case DataType::kINT64:
232+
return 8;
233+
case DataType::kBFLOAT16:
234+
return 2;
235+
case DataType::kFLOAT16:
236+
return 2;
237+
case DataType::kFLOAT32:
238+
return 4;
239+
case DataType::kFLOAT64:
240+
return 8;
241+
}
242+
return 0; // unreachable
243+
}
220244

221245
inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
222246
{DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
@@ -261,13 +285,24 @@ DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT32, uint32_t)
261285
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT32, int32_t)
262286
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kUINT64, uint64_t)
263287
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kINT64, int64_t)
264-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kBFLOAT16, BF16)
265-
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT16, FP16)
266288
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT32, float)
267289
DEFINE_DEFAULT_DATA_TYPE_MAPPING(kFLOAT64, double)
268290

269291
#undef DEFINE_DEFAULT_DATA_TYPE_MAPPING
270292

293+
// ---------------------------------------------------------------------------
294+
// Low-precision types: reverse mapping ONLY (DataTypeMap).
295+
// TypeMap<kFLOAT16> / TypeMap<kBFLOAT16> are intentionally NOT defined here.
296+
// Backend TypeMaps must explicitly provide these mappings; dispatch through the
297+
// default TypeMap hits a static_assert at compile time for an unmapped dtype.
298+
// ---------------------------------------------------------------------------
299+
template <> struct DataTypeMap<FP16> {
300+
static constexpr DataType value = DataType::kFLOAT16;
301+
};
302+
template <> struct DataTypeMap<BF16> {
303+
static constexpr DataType value = DataType::kBFLOAT16;
304+
};
305+
271306
// -----------------------------------------------------------------------------
272307
// Type traits extensions (framework fallback scalar semantics)
273308
// -----------------------------------------------------------------------------
@@ -365,4 +400,58 @@ template <typename... Ts> struct WidestType {
365400
// Convenience alias
366401
template <typename... Ts> using WidestType_t = typename WidestType<Ts...>::type;
367402

403+
// =============================================================================
404+
// DataType-level promotion (pure enum → enum, no concrete/backend types)
405+
// =============================================================================
406+
// These facilities replace `DataTypeMap_v<WidestType_t<Ta, Tb>>` in CUDA
407+
// kernels, so that backend kernels never need to know about __half /
408+
// __nv_bfloat16 at promotion time.
409+
//
410+
// Rules (priority order):
411+
// 1. FP16 + BF16 → FLOAT32 (neither is a lossless superset of the other)
412+
// 2. Any float dominates any integer → keep the float type
413+
// 3. Same category (float-float or int-int) → wider byte size wins; on equal sizes (e.g. INT32 vs UINT32) the first operand is kept
414+
// =============================================================================
415+
416+
/// Returns true for floating-point DataTypes (FP16, BF16, FP32, FP64).
417+
constexpr bool IsFloatingPointDType(DataType dt) {
418+
return dt == DataType::kFLOAT16 || dt == DataType::kBFLOAT16 || dt == DataType::kFLOAT32
419+
|| dt == DataType::kFLOAT64;
420+
}
421+
422+
/// Binary DataType promotion. Safe to call in both host and device code.
423+
constexpr DataType PromoteDataTypes(DataType a, DataType b) {
424+
if (a == b) {
425+
return a;
426+
}
427+
428+
// Rule 1: FP16 ↔ BF16 — no lossless path, promote to FP32
429+
if ((a == DataType::kFLOAT16 && b == DataType::kBFLOAT16)
430+
|| (a == DataType::kBFLOAT16 && b == DataType::kFLOAT16)) {
431+
return DataType::kFLOAT32;
432+
}
433+
434+
const bool a_fp = IsFloatingPointDType(a);
435+
const bool b_fp = IsFloatingPointDType(b);
436+
437+
// Rule 2: float beats integer
438+
if (a_fp && !b_fp) {
439+
return a;
440+
}
441+
if (b_fp && !a_fp) {
442+
return b;
443+
}
444+
445+
// Rule 3: same category — wider wins
446+
return DTypeSize(a) >= DTypeSize(b) ? a : b;
447+
}
448+
449+
/// Compile-time binary promotion: DataTypePromotion<A, B>::value
450+
template <DataType A, DataType B> struct DataTypePromotion {
451+
static constexpr DataType value = PromoteDataTypes(A, B);
452+
};
453+
454+
/// Convenience variable template
455+
template <DataType A, DataType B> inline constexpr DataType DataTypePromotion_v = DataTypePromotion<A, B>::value;
456+
368457
} // namespace infini_train

infini_train/include/dtype_dispatch.h

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,39 @@ template <typename T, typename... Ts> inline constexpr bool IsTypeInList = (std:
203203

204204
template <template <DataType> class TypeMap, DataType DType> using MappedType_t = typename TypeMap<DType>::type;
205205

206+
// -----------------------------------------------------------------------------
207+
// Detection trait: does TypeMap<DType> have a nested `type` alias?
208+
// Returns false (instead of a hard error) when the primary template is
209+
// undefined or the specialization intentionally omits `type`.
210+
// -----------------------------------------------------------------------------
211+
namespace detail {
212+
template <template <DataType> class TypeMap, DataType DType, typename = void> struct HasMappedType : std::false_type {};
213+
214+
template <template <DataType> class TypeMap, DataType DType>
215+
struct HasMappedType<TypeMap, DType, std::void_t<typename TypeMap<DType>::type>> : std::true_type {};
216+
} // namespace detail
217+
218+
template <template <DataType> class TypeMap, DataType DType>
219+
inline constexpr bool HasMappedType_v = detail::HasMappedType<TypeMap, DType>::value;
220+
206221
// -----------------------------------------------------------------------------
207222
// Generic single-dtype dispatch by custom type map
208223
// -----------------------------------------------------------------------------
224+
// Membership is checked by DataType (not by mapped C++ type) to avoid
225+
// premature instantiation of TypeMap<DType> for every AllowedDType.
226+
// After confirming DType is in the allowed set, a static_assert verifies
227+
// that TypeMap actually provides a mapping; only then is MappedType_t used.
228+
// -----------------------------------------------------------------------------
209229
template <template <DataType> class TypeMap, DataType... AllowedDTypes, typename Functor, typename... Args>
210230
auto DispatchByTypeMap(DataType dtype, Functor &&func, std::string_view context_identifier = "", Args &&...args) {
211231
switch (dtype) {
212232
#define CASE_FOR_TYPE(DType) \
213233
case DType: { \
214-
if constexpr (IsTypeInList<MappedType_t<TypeMap, DType>, MappedType_t<TypeMap, AllowedDTypes>...>) { \
234+
if constexpr (IsDataTypeInList_v<DType, DataTypeList<AllowedDTypes...>>) { \
235+
static_assert(HasMappedType_v<TypeMap, DType>, \
236+
"TypeMap does not provide explicit mapping for this dtype. " \
237+
"If this is a backend dispatch, register the dtype in the backend TypeMap; " \
238+
"if this is DispatchFunc, the dtype is not supported by the default TypeMap."); \
215239
return std::forward<Functor>(func).template operator()<MappedType_t<TypeMap, DType>>( \
216240
std::forward<Args>(args)...); \
217241
} else { \
@@ -257,6 +281,10 @@ struct TypeMapDispatcher {
257281
#define CASE_FOR_TYPE(DType) \
258282
case DType: \
259283
if constexpr (IsDataTypeInList_v<DType, CurrentList>) { \
284+
static_assert(HasMappedType_v<TypeMap, DType>, \
285+
"TypeMap does not provide explicit mapping for this dtype. " \
286+
"If this is a backend dispatch, register the dtype in the backend TypeMap; " \
287+
"if this is DispatchFunc, the dtype is not supported by the default TypeMap."); \
260288
using T = MappedType_t<TypeMap, DType>; \
261289
return TypeMapDispatcher<TypeMap, Index + 1, AllowedListTuple, ResolvedTypes..., T>::call( \
262290
dtypes, std::forward<Functor>(func), context_identifier, std::forward<Args>(args)...); \
@@ -309,6 +337,10 @@ auto DispatchByTypeMap(const std::vector<DataType> &dtypes, Functor &&func, std:
309337
// -----------------------------------------------------------------------------
310338
// Default framework dispatch using TypeMap
311339
// -----------------------------------------------------------------------------
340+
// TypeMap only covers standard types (int/uint/float32/float64).
341+
// Low-precision types (FP16/BF16) are intentionally unmapped — use a
342+
// backend-specific dispatch (DispatchCudaFunc, DispatchCpuFunc, …) instead.
343+
// -----------------------------------------------------------------------------
312344
template <DataType... AllowedDTypes, typename Functor, typename... Args>
313345
auto DispatchFunc(DataType dtype, Functor &&func, std::string_view context_identifier = "", Args &&...args) {
314346
return DispatchByTypeMap<TypeMap, AllowedDTypes...>(dtype, std::forward<Functor>(func), context_identifier,

0 commit comments

Comments
 (0)