Commit ee196dc

refactor(maca): adapt MACA kernels to new dtype dispatch and Scalar APIs
Port MACA backend to master's backend-explicit dtype registration:

- Add src/core/runtime/maca/maca_dispatch.h: register __half / __maca_bfloat16 via
  BackendTypeMap<kMACA, kFLOAT16/kBFLOAT16>, declare
  INFINI_REGISTER_STANDARD_BACKEND_TYPES(kMACA), and expose DispatchMacaFunc /
  MacaTypeMap mirroring the CUDA side.
- Replace every DispatchFunc<...>/WidestType_t/DataTypeMap_v site across 18 MACA
  kernels with DispatchMacaFunc / PromoteDataTypes.
- Replace Tensor::Fill<T>(0) template calls with Fill(0) to match the new
  Scalar-taking Tensor::Fill API.
- fill.maca: route Scalar::to<T> through common::maca::Cast<T>(scalar.to<float>())
  for __maca_bfloat16/__half to avoid ambiguous static_cast from integer Scalar
  kinds (see scalar.h TODO).
1 parent 91d6ad3 commit ee196dc

24 files changed

Lines changed: 265 additions & 108 deletions
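
Note: the fill.maca hunk referenced in the last bullet of the message is not among the files excerpted below. As a rough illustration only, a conversion helper along these lines would match the described routing; the function name and surrounding structure here are hypothetical, with only common::maca::Cast<T> and Scalar::to<float>() taken from the commit message.

// Hypothetical sketch, not code from this commit: route half/bfloat16 fills
// through float so integer Scalar kinds do not hit an ambiguous static_cast.
template <typename T>
T ScalarToMacaValue(const Scalar &scalar) {
    if constexpr (std::is_same_v<T, __half> || std::is_same_v<T, __maca_bfloat16>) {
        // Convert Scalar -> float on the host, then use the MACA cast helper.
        return common::maca::Cast<T>(scalar.to<float>());
    } else {
        // Other element types can take the Scalar's value directly.
        return scalar.to<T>();
    }
}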

CMakeLists.txt

Lines changed: 46 additions & 5 deletions
@@ -9,19 +9,23 @@ option(USE_OMP "Use OpenMP as backend for Eigen" ON)
 option(USE_NCCL "Build project for distributed running on CUDA using NCCL" ON)
 option(USE_MCCL "Build project for distributed running on MACA using MCCL" ON)
 
-project(infini_train VERSION 0.5.0 LANGUAGES CXX)
-
-# Switch to mxcc after project() so that third-party libs (glog, gflags) are
-# configured with the host compiler and their feature-detection checks pass.
+# ------------------------------------------------------------------------------
+# MACA toolchain override (must happen before project())
+# ------------------------------------------------------------------------------
+# When targeting MetaX MACA, the C/C++ compiler must be mxcc so that .maca
+# sources and device code can be compiled by the MACA toolchain.
 if(USE_MACA)
     set(MACA_PATH $ENV{MACA_PATH})
     if(NOT MACA_PATH)
-        message(FATAL_ERROR "USE_MACA=ON but environment variable MACA_PATH is not set.")
+        message(FATAL_ERROR "USE_MACA=ON but environment variable MACA_PATH is not set. "
+                "Please export MACA_PATH (e.g. /opt/maca) before configuring.")
     endif()
     set(CMAKE_C_COMPILER "${MACA_PATH}/mxgpu_llvm/bin/mxcc")
     set(CMAKE_CXX_COMPILER "${MACA_PATH}/mxgpu_llvm/bin/mxcc")
 endif()
 
+project(infini_train VERSION 0.5.0 LANGUAGES CXX)
+
 set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
@@ -41,8 +45,45 @@ include_directories(${gflags_SOURCE_DIR}/include)
 set(WITH_GFLAGS OFF CACHE BOOL "Disable glog finding system gflags" FORCE)
 set(WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE)
 set(BUILD_TESTING OFF CACHE BOOL "Disable glog unit tests" FORCE)
+# Build glog as a static lib so its symbols are always visible at link time.
+# Under mxcc the default symbol visibility is hidden, which causes the shared
+# libglog.so to export no symbols and produces "undefined reference" errors.
 set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build glog as static library" FORCE)
 
+# Under MACA/mxcc, cmake's feature-detection test compilations do not find
+# standard POSIX system headers (mxcc has a non-standard sysroot probe path).
+# Pre-set glog's HAVE_* cache variables so that glog skips its fallback type /
+# symbol definitions, which would otherwise conflict with the real system
+# headers during the actual build.
+if(USE_MACA)
+    set(HAVE_SYS_TYPES_H 1 CACHE INTERNAL "")
+    set(HAVE_UNISTD_H 1 CACHE INTERNAL "")
+    set(HAVE_DLFCN_H 1 CACHE INTERNAL "")
+    set(HAVE_GLOB_H 1 CACHE INTERNAL "")
+    set(HAVE_PWD_H 1 CACHE INTERNAL "")
+    set(HAVE_SYS_TIME_H 1 CACHE INTERNAL "")
+    set(HAVE_SYS_UTSNAME_H 1 CACHE INTERNAL "")
+    set(HAVE_SYS_WAIT_H 1 CACHE INTERNAL "")
+    set(HAVE_SYS_SYSCALL_H 1 CACHE INTERNAL "")
+    set(HAVE_SYSLOG_H 1 CACHE INTERNAL "")
+    set(HAVE_UCONTEXT_H 1 CACHE INTERNAL "")
+    # check_type_size() uses two internal variables: the size value and a sentinel
+    # "HAVE_HAVE_<VAR>" that marks the check as done. Pre-setting only the value
+    # is insufficient — the sentinel must also be set so the check skips entirely.
+    set(HAVE_MODE_T 4 CACHE INTERNAL "") # 4 bytes on Linux
+    set(HAVE_HAVE_MODE_T TRUE CACHE INTERNAL "")
+    set(HAVE_SSIZE_T 8 CACHE INTERNAL "") # 8 bytes on 64-bit Linux
+    set(HAVE_HAVE_SSIZE_T TRUE CACHE INTERNAL "")
+    set(HAVE_PREAD 1 CACHE INTERNAL "")
+    set(HAVE_PWRITE 1 CACHE INTERNAL "")
+    set(HAVE_POSIX_FADVISE 1 CACHE INTERNAL "")
+    set(HAVE_SIGACTION 1 CACHE INTERNAL "")
+    set(HAVE_SIGALTSTACK 1 CACHE INTERNAL "")
+    set(HAVE_FCNTL 1 CACHE INTERNAL "")
+    set(HAVE_DLADDR 1 CACHE INTERNAL "")
+    set(HAVE___CXA_DEMANGLE 1 CACHE INTERNAL "")
+endif()
+
 add_subdirectory(third_party/glog)
 include_directories(${glog_SOURCE_DIR}/src)

example/gpt2/main.cc

Lines changed: 19 additions & 0 deletions
@@ -29,6 +29,9 @@
 #ifdef PROFILE_MODE
 #include "infini_train/include/profiler.h"
 #endif
+#ifdef USE_MACA
+#include "infini_train/src/core/runtime/maca/maca_guard_impl.h"
+#endif
 #include "infini_train/include/nn/parallel/utils.h"
 #include "infini_train/include/utils/global_module_hook_registry.h"
 #include "infini_train/include/utils/precision_check_config.h"
@@ -452,12 +455,28 @@ void Train(const nn::parallel::Rank &rank) {
     Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("gpt2.records.log");
 #endif
+
+    // On MACA, flush all pending mcFreeAsync operations so that ATU entries for
+    // activation/gradient tensors from this step are released before the next
+    // forward pass begins. Without this, the ATU (address-translation unit)
+    // accumulates deferred frees across steps and becomes full, causing
+    // xnack(0x8) ATU-fault crashes in CastKernel and other large-tensor kernels.
+    if (device.type() == Device::DeviceType::kMACA) {
+        impl->SynchronizeDevice(device);
+    }
 }
 
 int main(int argc, char *argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
+    // On MACA, when TP > 1 disable P2P to prevent MCCL communication-ordering
+    // deadlocks and P2P teardown crashes. Must be set before any mcclCommInitAll
+    // call (i.e. before threads that create ProcessGroups are spawned).
+    if (FLAGS_device == kDeviceMACA && FLAGS_tensor_parallel > 1) {
+        setenv("MACA_P2P_DISABLE", "1", 1);
+    }
+
     auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
     nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
                                      FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

example/llama3/main.cc

Lines changed: 16 additions & 0 deletions
@@ -427,12 +427,28 @@ void Train(const nn::parallel::Rank &rank) {
     Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("llama3.records.log");
 #endif
+
+    // On MACA, flush all pending mcFreeAsync operations so that ATU entries for
+    // activation/gradient tensors from this step are released before the next
+    // forward pass begins. Without this, the ATU (address-translation unit)
+    // accumulates deferred frees across steps and becomes full, causing
+    // xnack(0x8) ATU-fault crashes in CastKernel and other large-tensor kernels.
+    if (device.type() == Device::DeviceType::kMACA) {
+        impl->SynchronizeDevice(device);
+    }
 }
 
 int main(int argc, char *argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
+    // On MACA, when TP > 1 disable P2P to prevent MCCL communication-ordering
+    // deadlocks and P2P teardown crashes. Must be set before any mcclCommInitAll
+    // call (i.e. before threads that create ProcessGroups are spawned).
+    if (FLAGS_device == kDeviceMACA && FLAGS_tensor_parallel > 1) {
+        setenv("MACA_P2P_DISABLE", "1", 1);
+    }
+
     auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
     nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
                                      FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

infini_train/src/core/runtime/maca/maca_dispatch.h

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+#pragma once
+
+#include <utility>
+#include <vector>
+
+#include <common/maca_bfloat16.h>
+#include <common/maca_fp16.h>
+
+#include "infini_train/include/core/backend_type_map.h"
+#include "infini_train/include/dtype_dispatch.h"
+
+// -----------------------------------------------------------------------------
+// MACA low-precision BackendTypeMap specializations:
+// FP16 -> __half, BF16 -> __maca_bfloat16
+// -----------------------------------------------------------------------------
+namespace infini_train::core {
+template <> struct BackendTypeMap<Device::DeviceType::kMACA, DataType::kFLOAT16> {
+    using type = __half;
+};
+
+template <> struct BackendTypeMap<Device::DeviceType::kMACA, DataType::kBFLOAT16> {
+    using type = __maca_bfloat16;
+};
+} // namespace infini_train::core
+
+// Register all standard (non-low-precision) dtypes for the MACA backend.
+// FP16/BF16 are registered explicitly above with their MACA-native scalar types.
+INFINI_REGISTER_STANDARD_BACKEND_TYPES(infini_train::Device::DeviceType::kMACA)
+
+namespace infini_train::core::maca {
+
+template <DataType DType> struct MacaTypeMap : BackendTypeMap<Device::DeviceType::kMACA, DType> {};
+
+// -----------------------------------------------------------------------------
+// MACA dispatch helpers
+// -----------------------------------------------------------------------------
+
+template <DataType... AllowedDTypes, typename Functor, typename... Args>
+auto DispatchMacaFunc(DataType dtype, Functor &&func, std::string_view context_identifier = "", Args &&...args) {
+    return infini_train::DispatchByTypeMap<MacaTypeMap, AllowedDTypes...>(
+        dtype, std::forward<Functor>(func), context_identifier, std::forward<Args>(args)...);
+}
+
+template <typename... AllowedTypeLists, typename Functor, typename... Args>
+auto DispatchMacaFunc(const std::vector<DataType> &dtypes, Functor &&func, std::string_view context_identifier = "",
+                      Args &&...args) {
+    return infini_train::DispatchByTypeMap<MacaTypeMap, AllowedTypeLists...>(
+        dtypes, std::forward<Functor>(func), context_identifier, std::forward<Args>(args)...);
+}
+
+} // namespace infini_train::core::maca
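
Not part of the diff: a hedged usage sketch of the single-dtype DispatchMacaFunc overload, mirroring the call sites changed in the kernels below. The helper name and kernel body are hypothetical; only the call shape (dtype, templated lambda, context string) follows the header above.

// Hypothetical call site, for illustration only.
void ScaleInPlace(const std::shared_ptr<Tensor> &tensor, float factor) {
    core::maca::DispatchMacaFunc<INFINI_ALL_FLOATING_TYPES>(
        tensor->Dtype(),
        [=]<typename T>() {
            // T resolves through MacaTypeMap: kFLOAT16 -> __half, kBFLOAT16 -> __maca_bfloat16,
            // so the lambda body sees the MACA-native scalar type.
            auto *data = static_cast<T *>(tensor->DataPtr());
            // ... launch a scaling kernel over `data` here ...
        },
        "MACA ScaleInPlace");
}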

infini_train/src/kernels/maca/accumulate_grad.maca

Lines changed: 3 additions & 2 deletions
@@ -6,6 +6,7 @@
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 
+#include "infini_train/src/core/runtime/maca/maca_dispatch.h"
 #include "infini_train/src/core/runtime/maca/maca_runtime_common.h"
 
 namespace infini_train::kernels::maca {
@@ -29,7 +30,7 @@ void AccumulateGrad(const std::shared_ptr<Tensor> &gradient, float rate, const s
             infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
             ->maca_stream();
 
-    DispatchFunc<INFINI_ALL_FLOATING_TYPES>(
+    core::maca::DispatchMacaFunc<INFINI_ALL_FLOATING_TYPES>(
         gradient->Dtype(),
         [=]<typename T>() {
            AccumulateGradKernel<<<num_blocks, threads_per_block, 0, maca_stream>>>(
@@ -73,7 +74,7 @@ void AdamAccumulateGrad(const std::shared_ptr<Tensor> &grad, const std::shared_p
             infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
             ->maca_stream();
 
-    DispatchFunc<INFINI_ALL_FLOATING_TYPES>(
+    core::maca::DispatchMacaFunc<INFINI_ALL_FLOATING_TYPES>(
         grad->Dtype(),
         [=]<typename T>() {
            AdamAccumulateGradKernel<<<num_blocks, threads_per_block, 0, maca_stream>>>(

infini_train/src/kernels/maca/cast.maca

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 
+#include "infini_train/src/core/runtime/maca/maca_dispatch.h"
 #include "infini_train/src/core/runtime/maca/maca_runtime_common.h"
 
 namespace infini_train::kernels::maca {
@@ -33,7 +34,7 @@ std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {
     dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x));
     const size_t step = grid_dims.x * block_dims.x;
 
-    DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
+    core::maca::DispatchMacaFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
         {dtype, input->Dtype()},
         [=]<typename Tdst, typename Tsrc>() {
            auto dst = static_cast<Tdst *>(dst_tensor->DataPtr());

infini_train/src/kernels/maca/comm.maca

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ std::vector<std::shared_ptr<Tensor>> ReduceAddCoalesced(const std::vector<std::v
     std::vector<std::vector<std::shared_ptr<Tensor>>> to_destination_grads;
     for (int i = 0; i < grads[0].size(); ++i) {
         outputs.emplace_back(std::make_shared<Tensor>(grads[0][i]->Dims(), grads[0][i]->Dtype(), destination));
-        outputs[i]->Fill<float>(0.0);
+        outputs[i]->Fill(0.0);
     }
     for (int i = 0; i < grads.size(); ++i) {
         to_destination_grads.push_back(std::vector<std::shared_ptr<Tensor>>());

infini_train/src/kernels/maca/concat.maca

Lines changed: 5 additions & 4 deletions
@@ -11,6 +11,7 @@
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 
+#include "infini_train/src/core/runtime/maca/maca_dispatch.h"
 #include "infini_train/src/core/runtime/maca/maca_runtime_common.h"
 
 namespace infini_train::kernels::maca {
@@ -102,7 +103,7 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
     int threads_per_block = 256;
     int num_blocks = static_cast<int>((total + threads_per_block - 1) / threads_per_block);
 
-    DispatchFunc<INFINI_ALL_TYPES>(
+    core::maca::DispatchMacaFunc<INFINI_ALL_TYPES>(
         dtype,
         [=, &inputs, &host_offsets]<typename T>() {
            std::vector<const T *> host_input_ptrs;
@@ -185,8 +186,8 @@ std::vector<std::shared_ptr<Tensor>> ConcatBackward(const std::shared_ptr<Tensor
     grads.reserve(input_dims_list.size());
     for (const auto &dvec : input_dims_list) {
         auto t = std::make_shared<Tensor>(dvec, dtype, device);
-        DispatchFunc<INFINI_ALL_TYPES>(
-            dtype, [=]<typename T>() { t->Fill<T>(0); }, "MACA ConcatBackward");
+        core::maca::DispatchMacaFunc<INFINI_ALL_TYPES>(
+            dtype, [=]<typename T>() { t->Fill(0); }, "MACA ConcatBackward");
         grads.push_back(t);
     }
 
@@ -208,7 +209,7 @@ std::vector<std::shared_ptr<Tensor>> ConcatBackward(const std::shared_ptr<Tensor
     int threads_per_block = 256;
     int num_blocks = static_cast<int>((total + threads_per_block - 1) / threads_per_block);
 
-    DispatchFunc<INFINI_ALL_TYPES>(
+    core::maca::DispatchMacaFunc<INFINI_ALL_TYPES>(
         dtype,
         [=, &grads, &host_offsets]<typename T>() {
            std::vector<T *> host_ptrs;

infini_train/src/kernels/maca/cross_entropy.maca

Lines changed: 4 additions & 3 deletions
@@ -12,6 +12,7 @@
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 
+#include "infini_train/src/core/runtime/maca/maca_dispatch.h"
 #include "infini_train/src/core/runtime/maca/maca_runtime_common.h"
 
 namespace infini_train::kernels::maca {
@@ -91,7 +92,7 @@ std::shared_ptr<Tensor> CrossEntropyForward(const std::shared_ptr<Tensor> &input
             infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
             ->maca_stream();
 
-    return DispatchFunc<DataTypeList<DataType::kUINT8, DataType::kINT64>, DataTypeList<INFINI_ALL_FLOATING_TYPES>>(
+    return core::maca::DispatchMacaFunc<DataTypeList<DataType::kUINT8, DataType::kINT64>, DataTypeList<INFINI_ALL_FLOATING_TYPES>>(
         {target->Dtype(), input->Dtype()},
         [=]<typename Ttarget, typename Tinput>() {
            const Ttarget *target_ptr = static_cast<const Ttarget *>(target->DataPtr());
@@ -198,10 +199,10 @@ std::shared_ptr<Tensor> CrossEntropyBackward(const std::shared_ptr<Tensor> &inpu
             infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
             ->maca_stream();
 
-    DispatchFunc<DataTypeList<DataType::kUINT8, DataType::kINT64>, DataTypeList<INFINI_ALL_FLOATING_TYPES>>(
+    core::maca::DispatchMacaFunc<DataTypeList<DataType::kUINT8, DataType::kINT64>, DataTypeList<INFINI_ALL_FLOATING_TYPES>>(
         {target->Dtype(), input_casted->Dtype()},
         [=]<typename Ttarget, typename Tinput>() {
-            grad_input->Fill<Tinput>(0);
+            grad_input->Fill(0);
             const Tinput *output_grad_ptr = static_cast<const Tinput *>(grad_output->DataPtr());
             const Ttarget *target_ptr = static_cast<const Ttarget *>(target->DataPtr());
             const Tinput *input_ptr = static_cast<const Tinput *>(input_casted->DataPtr());

infini_train/src/kernels/maca/elementwise.maca

Lines changed: 10 additions & 15 deletions
@@ -8,6 +8,7 @@
 #include "infini_train/include/dispatcher.h"
 #include "infini_train/include/tensor.h"
 
+#include "infini_train/src/core/runtime/maca/maca_dispatch.h"
 #include "infini_train/src/core/runtime/maca/maca_runtime_common.h"
 
 namespace infini_train::kernels::maca {
@@ -766,9 +767,7 @@ std::shared_ptr<Tensor> UnaryBackward(const std::shared_ptr<Tensor> &grad_output
                                       Func unary_fn) {
     auto dtype = grad_output->Dtype();
     auto a_dtype = a ? a->Dtype() : dtype;
-    DataType promoted_type = DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
-        {dtype, a_dtype}, [=]<typename Tgrad, typename Ta>() { return DataTypeMap_v<WidestType_t<Tgrad, Ta>>; },
-        "MACA UnaryBackward");
+    DataType promoted_type = PromoteDataTypes(dtype, a_dtype);
 
     auto grad_output_promoted
         = dtype == promoted_type ? grad_output : std::make_shared<Tensor>(grad_output->To(promoted_type));
@@ -795,9 +794,7 @@ std::shared_ptr<Tensor> BinaryForward(const std::shared_ptr<Tensor> &a, const st
     auto a_dtype = a->Dtype();
     auto b_dtype = b->Dtype();
 
-    DataType promoted_type = DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
-        {a_dtype, b_dtype}, [=]<typename Ta, typename Tb>() { return DataTypeMap_v<WidestType_t<Ta, Tb>>; },
-        "MACA BinaryForward");
+    DataType promoted_type = PromoteDataTypes(a_dtype, b_dtype);
 
     auto a_promoted = a_dtype == promoted_type ? a : std::make_shared<Tensor>(a->To(promoted_type));
     auto b_promoted = b_dtype == promoted_type ? b : std::make_shared<Tensor>(b->To(promoted_type));
@@ -837,9 +834,7 @@ BinaryBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr
     auto a_dtype = a_promoted ? a_promoted->Dtype() : dtype;
     auto b_dtype = b_promoted ? b_promoted->Dtype() : dtype;
     // Compute dtype determined by saved tensors (forward compute dtype), not grad_output
-    DataType promoted_type = DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
-        {a_dtype, b_dtype}, [=]<typename Ta, typename Tb>() { return DataTypeMap_v<WidestType_t<Ta, Tb>>; },
-        "MACA BinaryBackward");
+    DataType promoted_type = PromoteDataTypes(a_dtype, b_dtype);
 
     CHECK(a_num_elements >= b_num_elements && a_num_elements % b_num_elements == 0);
 
@@ -867,26 +862,26 @@ BinaryBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr
     switch (promoted_type) {
         DISPATCH_CASE(WRAP({
             if (needs_broadcast) {
-                grad_a->Fill<float>(0.0f);
-                grad_b->Fill<float>(0.0f);
+                grad_a->Fill(0.0f);
+                grad_b->Fill(0.0f);
             }
             LaunchBackward<256, float>(fn_a, fn_b, grad_a, grad_b, a_dims, b_dims, grad_output_promoted,
                                        a_promoted, b_promoted);
         }),
         DataType::kFLOAT32)
         DISPATCH_CASE(WRAP({
             if (needs_broadcast) {
-                grad_a->Fill<__maca_bfloat16>(0);
-                grad_b->Fill<__maca_bfloat16>(0);
+                grad_a->Fill(0);
+                grad_b->Fill(0);
             }
             LaunchBackward<256, __maca_bfloat16>(fn_a, fn_b, grad_a, grad_b, a_dims, b_dims,
                                                  grad_output_promoted, a_promoted, b_promoted);
         }),
         DataType::kBFLOAT16)
         // FIXME(zbl): AtomicAdd does not support int64_t
         // DISPATCH_CASE(WRAP({
-        //     grad_a->Fill<int64_t>(0);
-        //     grad_b->Fill<int64_t>(0);
+        //     grad_a->Fill(0);
+        //     grad_b->Fill(0);
        //      LaunchBackward<256, int64_t>(fn_a, fn_b, grad_a, grad_b, a_dims, b_dims, grad_output, a,
        //                                   b);
        //  }),
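
A note on the PromoteDataTypes substitution above, stated as an assumption rather than taken from the repository: the removed dispatch computed the widest of the two runtime dtypes via WidestType_t / DataTypeMap_v, and the new one-liner is expected to return that same promoted DataType. Illustratively:

// Old pattern (removed above): dispatch over both dtypes and map the widest C++ type back to a DataType.
DataType promoted_old = DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
    {a_dtype, b_dtype}, [=]<typename Ta, typename Tb>() { return DataTypeMap_v<WidestType_t<Ta, Tb>>; },
    "example");
// New pattern: assumed-equivalent direct promotion, e.g. (kBFLOAT16, kFLOAT32) -> kFLOAT32.
DataType promoted_new = PromoteDataTypes(a_dtype, b_dtype);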
