Skip to content

Commit 6490093

Browse files
committed
refactor: remove old promotion code implemented by WidestType_t
1 parent 0c205c6 commit 6490093

4 files changed

Lines changed: 6 additions & 108 deletions

File tree

infini_train/include/datatype.h

Lines changed: 0 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <cmath>
55
#include <cstdint>
66
#include <string>
7-
#include <type_traits>
87
#include <unordered_map>
98

109
namespace infini_train {
@@ -303,110 +302,9 @@ template <> struct DataTypeMap<BF16> {
303302
static constexpr DataType value = DataType::kBFLOAT16;
304303
};
305304

306-
// -----------------------------------------------------------------------------
307-
// Type traits extensions (framework fallback scalar semantics)
308-
// -----------------------------------------------------------------------------
309-
template <typename T> struct is_floating_point_ext : std::is_floating_point<T> {};
310-
311-
template <typename T> struct is_arithmetic_ext : std::is_arithmetic<T> {};
312-
313-
template <> struct is_floating_point_ext<BF16> : std::true_type {};
314-
template <> struct is_arithmetic_ext<BF16> : std::true_type {};
315-
316-
template <> struct is_floating_point_ext<FP16> : std::true_type {};
317-
template <> struct is_arithmetic_ext<FP16> : std::true_type {};
318-
319-
// -----------------------------------------------------------------------------
320-
// Promotion helpers (framework-level WidestType)
321-
// -----------------------------------------------------------------------------
322-
namespace detail {
323-
324-
template <typename T1, typename T2> struct LargerType {
325-
static constexpr size_t size1 = sizeof(T1);
326-
static constexpr size_t size2 = sizeof(T2);
327-
using type = std::conditional_t<(size1 >= size2), T1, T2>;
328-
};
329-
330-
template <> struct LargerType<BF16, FP16> {
331-
using type = float;
332-
};
333-
334-
template <> struct LargerType<FP16, BF16> {
335-
using type = float;
336-
};
337-
338-
/**
339-
* @brief Finds the first type in a parameter pack that satisfies the given predicate.
340-
* If no type matches, returns the last type in the pack (base case).
341-
*/
342-
template <template <typename> class Predicate, typename... Ts> struct FirstMatchingType;
343-
344-
template <template <typename> class Predicate, typename T> struct FirstMatchingType<Predicate, T> {
345-
using type = T;
346-
};
347-
348-
template <template <typename> class Predicate, typename T, typename... Ts>
349-
struct FirstMatchingType<Predicate, T, Ts...> {
350-
using type = std::conditional_t<Predicate<T>::value, T, typename FirstMatchingType<Predicate, Ts...>::type>;
351-
};
352-
353-
/**
354-
* @brief Recursively finds the widest type among those that satisfy a predicate.
355-
* Types not satisfying the predicate are ignored and don't affect the current maximum.
356-
*/
357-
template <template <typename> class Predicate, typename CurrentMax, typename... Ts> struct WidestTypeImpl;
358-
359-
template <template <typename> class Predicate, typename CurrentMax> struct WidestTypeImpl<Predicate, CurrentMax> {
360-
using type = CurrentMax;
361-
};
362-
363-
template <template <typename> class Predicate, typename CurrentMax, typename T, typename... Ts>
364-
struct WidestTypeImpl<Predicate, CurrentMax, T, Ts...> {
365-
using new_max = std::conditional_t<Predicate<T>::value, typename LargerType<CurrentMax, T>::type, CurrentMax>;
366-
using type = typename WidestTypeImpl<Predicate, new_max, Ts...>::type;
367-
};
368-
369-
template <template <typename> class Predicate, typename... Ts> struct MaxTypeBySizeWithPredicate {
370-
using first = typename FirstMatchingType<Predicate, Ts...>::type;
371-
using type = typename WidestTypeImpl<Predicate, first, Ts...>::type;
372-
};
373-
374-
} // namespace detail
375-
376-
/**
377-
* @brief Finds the widest/largest type according to a PyTorch-like dtype promotion rule among a pack of arithmetic
378-
* types.
379-
*
380-
* - If floating-point types are present, selects the largest floating-point type;
381-
* - Otherwise selects the largest integral type.
382-
* - If multiple integral types have the same size, precedence follows the list order.
383-
*
384-
* Note:
385-
* - FP16/BF16 are treated as floating-point.
386-
* - Mixed FP16 and BF16 promotes to float (32-bit).
387-
*/
388-
template <typename... Ts> struct WidestType {
389-
static_assert(sizeof...(Ts) > 0, "At least one type is required");
390-
static_assert((is_arithmetic_ext<Ts>::value && ...),
391-
"All types must be arithmetic or framework floating-point types (FP16/BF16)");
392-
393-
static constexpr bool has_float = (is_floating_point_ext<Ts>::value || ...);
394-
395-
using type =
396-
typename std::conditional_t<has_float, detail::MaxTypeBySizeWithPredicate<is_floating_point_ext, Ts...>,
397-
detail::MaxTypeBySizeWithPredicate<std::is_integral, Ts...>>::type;
398-
};
399-
400-
// Convenience alias
401-
template <typename... Ts> using WidestType_t = typename WidestType<Ts...>::type;
402-
403305
// =============================================================================
404306
// DataType-level promotion (pure enum → enum, no concrete/backend types)
405307
// =============================================================================
406-
// These facilities replace `DataTypeMap_v<WidestType_t<Ta, Tb>>` in CUDA
407-
// kernels, so that backend kernels never need to know about __half /
408-
// __nv_bfloat16 at promotion time.
409-
//
410308
// Rules (priority order):
411309
// 1. FP16 + BF16 → FLOAT32 (neither is a lossless superset of the other)
412310
// 2. Any float dominates any integer → keep the float type

infini_train/src/kernels/cpu/cast.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "infini_train/include/dtype_dispatch.h"
66
#include "infini_train/include/tensor.h"
77

8-
#include "infini_train/src/core/cpu/cpu_dispatch.h"
8+
#include "infini_train/src/core/runtime/cpu/cpu_dispatch.h"
99

1010
namespace infini_train::kernels::cpu {
1111

infini_train/src/kernels/cpu/fill.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "infini_train/include/dtype_dispatch.h"
66
#include "infini_train/include/tensor.h"
77

8-
#include "infini_train/src/core/cpu/cpu_dispatch.h"
8+
#include "infini_train/src/core/runtime/cpu/cpu_dispatch.h"
99

1010
namespace infini_train::kernels::cpu {
1111
void Fill(std::shared_ptr<Tensor> tensor, double value) {

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,17 +861,17 @@ BinaryBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr
861861
switch (promoted_type) {
862862
DISPATCH_CASE(WRAP({
863863
if (needs_broadcast) {
864-
grad_a->Fill<float>(0.0f);
865-
grad_b->Fill<float>(0.0f);
864+
grad_a->Fill(0.0f);
865+
grad_b->Fill(0.0f);
866866
}
867867
LaunchBackward<256, float>(fn_a, fn_b, grad_a, grad_b, a_dims, b_dims, grad_output_promoted,
868868
a_promoted, b_promoted);
869869
}),
870870
DataType::kFLOAT32)
871871
DISPATCH_CASE(WRAP({
872872
if (needs_broadcast) {
873-
grad_a->Fill<nv_bfloat16>(0);
874-
grad_b->Fill<nv_bfloat16>(0);
873+
grad_a->Fill(0.0f);
874+
grad_b->Fill(0.0f);
875875
}
876876
LaunchBackward<256, nv_bfloat16>(fn_a, fn_b, grad_a, grad_b, a_dims, b_dims,
877877
grad_output_promoted, a_promoted, b_promoted);

0 commit comments

Comments (0)