
Commit 421f683

feat(maca): add MetaX MACA backend skeleton and minimal kernels
Introduce a MACA (MetaX 沐曦) backend plugged into the DeviceGuardImpl / kernel dispatcher framework, targeting the minimal kernel set needed to validate single-card fp32 training (e.g. mnist) end-to-end:

- Build system: USE_MACA / USE_MCCL options, mxcc toolchain override, mxomp linkage under USE_OMP, a .maca kernel library compiled with -x maca, and backend-exclusive SRC filtering so non-target backends are not pulled in.
- Device enum: add Device::DeviceType::kMACA (kCount bumped to 3), IsMACA(), and a three-way ToString() switch.
- common/maca: MACA_CHECK / MCBLAS_CHECK / MCCL_CHECK macros and the kernel_helper.cuh template library (Cast/Neg/Sin/Pow/Add/Sub/Mul/Div/Max/Min/Fma/fastAtomicAdd), plus a cub_compat.cuh shim pinning CubSumOp/CubMaxOp/CubMinOp to the pre-2.8 CUB API that MACA ships.
- core/runtime/maca: MacaStream / MacaEvent / MacaBlasHandle derived from core::Stream / Event / BlasHandle, and MacaGuardImpl mirroring CudaGuardImpl (mcInit(0) in the ctor, call_once'd default stream/handle caches, and the full stream/event/sync/blas/memory surface). Mempool watermark hooks are stubs pending SDK verification.
- datatype.h / tensor.cc / nn/init.cc: add USE_MACA branches to map kBFLOAT16 / kFLOAT16 to __maca_bfloat16 / __half, specialize the is_floating_point_ext / is_arithmetic_ext / LargerType traits, route Fill casts through float under real device backends to dodge the ambiguous __half(int) constructor on MACA, and wire Arange for bf16/fp16.
- kernels/maca: mechanically port the minimal 5-kernel slice (elementwise, linear, fill, no_op, accumulate_grad) from their .cu counterparts, switching blas/stream acquisition to the new GetDeviceGuardImpl()->GetBlasHandle()/GetStream() idiom.

The MCCL collective backend and the remaining 15 kernels (which are required for gpt2 / DDP) will land in a follow-up commit.
1 parent be8d5a8 commit 421f683

16 files changed

Lines changed: 2936 additions & 5 deletions
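
The device-enum change summarized in the commit message is not among the diff excerpts shown below. As a reference, a minimal sketch of what it describes, assuming the enum follows the existing kCPU / kCUDA naming; member layout beyond the names mentioned in the message is an assumption:

#include <cstdint>
#include <string>

// Sketch only: mirrors the commit message's description of the Device enum
// change (kMACA added, kCount bumped to 3, IsMACA(), three-way ToString()).
class Device {
public:
    enum class DeviceType : int8_t {
        kCPU = 0,
        kCUDA = 1,
        kMACA = 2, // new MetaX backend
        kCount = 3,
    };

    explicit Device(DeviceType type) : type_(type) {}

    bool IsMACA() const { return type_ == DeviceType::kMACA; }

    std::string ToString() const {
        switch (type_) {
        case DeviceType::kCPU:
            return "cpu";
        case DeviceType::kCUDA:
            return "cuda";
        case DeviceType::kMACA:
            return "maca";
        default:
            return "unknown";
        }
    }

private:
    DeviceType type_;
};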

CMakeLists.txt

Lines changed: 119 additions & 2 deletions
@@ -1,9 +1,28 @@
 cmake_minimum_required(VERSION 3.28)
 
+# Platforms
 option(USE_CUDA "Support NVIDIA CUDA" OFF)
+option(USE_MACA "Support MetaX MACA" OFF)
+
 option(PROFILE_MODE "ENABLE PROFILE MODE" OFF)
 option(USE_OMP "Use OpenMP as backend for Eigen" ON)
-option(USE_NCCL "Build project for distributed running" ON)
+option(USE_NCCL "Build project for distributed running on CUDA using NCCL" ON)
+option(USE_MCCL "Build project for distributed running on MACA using MCCL" ON)
+
+# ------------------------------------------------------------------------------
+# MACA toolchain override (must happen before project())
+# ------------------------------------------------------------------------------
+# When targeting MetaX MACA, the C/C++ compiler must be mxcc so that .maca
+# sources and device code can be compiled by the MACA toolchain.
+if(USE_MACA)
+    set(MACA_PATH $ENV{MACA_PATH})
+    if(NOT MACA_PATH)
+        message(FATAL_ERROR "USE_MACA=ON but environment variable MACA_PATH is not set. "
+                            "Please export MACA_PATH (e.g. /opt/maca) before configuring.")
+    endif()
+    set(CMAKE_C_COMPILER "${MACA_PATH}/mxgpu_llvm/bin/mxcc")
+    set(CMAKE_CXX_COMPILER "${MACA_PATH}/mxgpu_llvm/bin/mxcc")
+endif()
 
 project(infini_train VERSION 0.5.0 LANGUAGES CXX)
 
@@ -31,6 +50,22 @@ include_directories(${glog_SOURCE_DIR}/src)
 # eigen
 if(USE_OMP)
     find_package(OpenMP REQUIRED)
+
+    set(INFINI_OMP_LIBS OpenMP::OpenMP_CXX)
+
+    # Under MACA/mxcc, the host compiler is LLVM-based; link mxomp (iomp5) instead
+    # of libgomp to stay ABI-compatible with the MACA toolchain.
+    if(USE_MACA)
+        find_library(INFINI_MACA_OMP_LIB
+            NAMES omp iomp5
+            HINTS
+                "${MACA_PATH}/lib"
+                "${MACA_PATH}/mxgpu_llvm/lib"
+                "${MACA_PATH}/mxgpu_llvm/lib64"
+            REQUIRED
+        )
+        set(INFINI_OMP_LIBS OpenMP::OpenMP_CXX ${INFINI_MACA_OMP_LIB})
+    endif()
 endif()
 add_subdirectory(third_party/eigen)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)
@@ -48,9 +83,25 @@ endif()
 # Framework core sources (*.cc), excluding cpu kernels (they are built separately)
 file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
 list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
+
+# Exclude backend-specific runtime/ccl translation units when the corresponding
+# backend is disabled. This keeps each build self-contained and avoids pulling
+# in headers (e.g. <cuda_runtime.h> / <mcr/mc_runtime.h>) that aren't on the
+# include path.
+if(NOT USE_CUDA)
+    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/runtime/cuda/.*")
+    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
+endif()
+if(NOT USE_MACA)
+    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/runtime/maca/.*")
+    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/maca/.*")
+endif()
 if(NOT USE_NCCL)
     list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
 endif()
+if(NOT USE_MCCL)
+    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/maca/.*")
+endif()
 
 # CPU kernels (*.cc)
 file(GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc)
@@ -64,7 +115,7 @@ target_link_libraries(infini_train_cpu_kernels PUBLIC glog Eigen3::Eigen)
 
 if(USE_OMP)
     add_compile_definitions(USE_OMP=1)
-    target_link_libraries(infini_train_cpu_kernels PUBLIC OpenMP::OpenMP_CXX)
+    target_link_libraries(infini_train_cpu_kernels PUBLIC ${INFINI_OMP_LIBS})
 endif()
 
 # ------------------------------------------------------------------------------
@@ -103,6 +154,46 @@ if(USE_CUDA)
     endif()
 endif()
 
+# ------------------------------------------------------------------------------
+# MACA kernels library (optional, MetaX backend)
+# ------------------------------------------------------------------------------
+
+if(USE_MACA)
+    add_compile_definitions(USE_MACA=1)
+
+    # ---- MACA SDK include / link paths ----
+    include_directories("${MACA_PATH}/include")
+    link_directories("${MACA_PATH}/lib")
+
+    # ---- MACA runtime / blas / (optional) mccl libraries ----
+    find_library(MACA_RUNTIME_LIB NAMES mcruntime HINTS "${MACA_PATH}/lib" REQUIRED)
+    find_library(MACA_DNN_LIB NAMES mcdnn HINTS "${MACA_PATH}/lib" REQUIRED)
+    find_library(MACA_BLAS_LIB NAMES mcblas HINTS "${MACA_PATH}/lib" REQUIRED)
+
+    # ---- Collect .maca kernel sources and build as a CXX static lib with -x maca ----
+    file(GLOB_RECURSE MACA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/maca/*.maca)
+    set_source_files_properties(${MACA_KERNELS} PROPERTIES
+        LANGUAGE CXX
+        COMPILE_OPTIONS "-x;maca"
+    )
+
+    add_library(infini_train_maca_kernels STATIC ${MACA_KERNELS})
+    target_link_libraries(infini_train_maca_kernels
+        PUBLIC
+            glog
+            ${MACA_RUNTIME_LIB}
+            ${MACA_DNN_LIB}
+            ${MACA_BLAS_LIB}
+    )
+
+    if(USE_MCCL)
+        message(STATUS "Add USE_MCCL, use MCCL with MACA")
+        find_library(MACA_COMM_LIB NAMES mccl HINTS "${MACA_PATH}/lib" REQUIRED)
+        add_compile_definitions(USE_MCCL=1)
+        target_link_libraries(infini_train_maca_kernels PUBLIC ${MACA_COMM_LIB})
+    endif()
+endif()
+
 # ------------------------------------------------------------------------------
 # Main framework library
 # ------------------------------------------------------------------------------
@@ -133,6 +224,22 @@ if(USE_CUDA)
     endif()
 endif()
 
+if(USE_MACA)
+    # infini_train contains MACA runtime wrappers (maca_guard_impl.cc / maca_runtime_common.cc /
+    # mccl_impl.cc) which reference mcruntime / mcblas / mccl symbols directly at final link.
+    target_link_libraries(infini_train
+        PUBLIC
+            infini_train_maca_kernels
+            ${MACA_RUNTIME_LIB}
+            ${MACA_DNN_LIB}
+            ${MACA_BLAS_LIB}
+    )
+
+    if(USE_MCCL)
+        target_link_libraries(infini_train PUBLIC ${MACA_COMM_LIB})
+    endif()
+endif()
+
 # ------------------------------------------------------------------------------
 # Helper: link libraries in a group to fix static lib one-pass resolution
 # (THIS is what fixes "undefined reference" from cuda_kernels -> core symbols)
@@ -148,6 +255,16 @@ function(link_infini_train_exe target_name)
            "-Wl,--no-whole-archive"
            "-Wl,--end-group"
        )
+    elseif(USE_MACA)
+        target_link_libraries(${target_name} PRIVATE
+            "-Wl,--start-group"
+            "-Wl,--whole-archive"
+            infini_train
+            infini_train_cpu_kernels
+            infini_train_maca_kernels
+            "-Wl,--no-whole-archive"
+            "-Wl,--end-group"
+        )
     else()
        target_link_libraries(${target_name} PRIVATE
            "-Wl,--start-group"

infini_train/include/autocast.h

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ inline const std::unordered_map<std::string_view, CastPolicy> kOpCastPolicyMap =
 // Default autocast data types for each device type
 inline constexpr std::array<DataType, static_cast<size_t>(Device::DeviceType::kCount)> kDeviceDefaultDtype = {
     DataType::kBFLOAT16, // CPU
-    DataType::kFLOAT16,  // CUDA.
+    DataType::kFLOAT16,  // CUDA
+    DataType::kFLOAT16,  // MACA
 };
 
 // Thread-local context to track autocast state
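
The new entry makes the array cover all three DeviceType values. As a reference, a lookup amounts to the following sketch; the Type() accessor is an assumption, since this hunk only shows the array itself:

// Hypothetical helper; assumes Device exposes a Type() accessor returning
// Device::DeviceType, which this diff does not show.
inline DataType DefaultAutocastDtype(const Device &device) {
    return kDeviceDefaultDtype[static_cast<size_t>(device.Type())];
}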
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+#pragma once
+
+#include <mcr/mc_runtime.h>
+#include <mcr/mc_runtime_api.h>
+#include <mcblas/mcblas.h>
+
+#ifdef USE_MCCL
+#include <mccl.h>
+#endif
+
+#include "glog/logging.h"
+
+namespace infini_train::common::maca {
+
+// Common MACA Macros
+#define MACA_CHECK(call)                                                                                       \
+    do {                                                                                                       \
+        mcError_t status = call;                                                                               \
+        if (status != mcSuccess) {                                                                             \
+            LOG(FATAL) << "MACA Error: " << mcGetErrorString(status) << " at " << __FILE__ << ":" << __LINE__; \
+        }                                                                                                      \
+    } while (0)
+
+#define MCBLAS_CHECK(call)                                                                                         \
+    do {                                                                                                           \
+        mcblasStatus_t status = call;                                                                              \
+        if (status != MCBLAS_STATUS_SUCCESS) {                                                                     \
+            LOG(FATAL) << "MCBLAS Error: " << mcblasGetStatusString(status) << " at " << __FILE__ << ":" << __LINE__; \
+        }                                                                                                          \
+    } while (0)
+
+#ifdef USE_MCCL
+#define MCCL_CHECK(expr)                                                                                         \
+    do {                                                                                                         \
+        mcclResult_t _status = (expr);                                                                           \
+        if (_status != mcclSuccess) {                                                                            \
+            LOG(FATAL) << "MCCL error: " << mcclGetErrorString(_status) << " at " << __FILE__ << ":" << __LINE__ \
+                       << " (" << #expr << ")";                                                                  \
+        }                                                                                                        \
+    } while (0)
+#endif
+
+} // namespace infini_train::common::maca
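
A minimal usage sketch for these macros, assuming the mc-prefixed runtime mirrors the CUDA API; mcMalloc / mcMemcpy / mcMemcpyHostToDevice / mcMemcpyDeviceToHost / mcFree are assumed by analogy and are not shown in this commit:

// Sketch: every runtime call is wrapped so a failure aborts with file:line
// context via LOG(FATAL). The mc* calls below are assumed to mirror their
// cuda* counterparts.
void RoundTrip(const float *host_src, float *host_dst, size_t n) {
    void *dev = nullptr;
    MACA_CHECK(mcMalloc(&dev, n * sizeof(float)));
    MACA_CHECK(mcMemcpy(dev, host_src, n * sizeof(float), mcMemcpyHostToDevice));
    MACA_CHECK(mcMemcpy(host_dst, dev, n * sizeof(float), mcMemcpyDeviceToHost));
    MACA_CHECK(mcFree(dev));
}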
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <cub/cub.cuh>
+
+namespace infini_train::kernels::maca {
+
+// MACA ships a CUB compatible with the pre-2.8 API (cub::Sum/Max/Min).
+// Mirror the CUDA cub_compat.cuh aliases so that kernel code can refer to
+// CubSumOp / CubMaxOp / CubMinOp uniformly across backends.
+using CubSumOp = cub::Sum;
+using CubMaxOp = cub::Max;
+using CubMinOp = cub::Min;
+
+} // namespace infini_train::kernels::maca
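
A usage sketch showing why the aliases matter: kernel code can reduce with CubSumOp regardless of which CUB API the backend ships. The kernel shape and block size here are illustrative, not taken from this commit:

#include <cub/cub.cuh>

// Sketch: per-block sum using the pre-2.8 CUB functor API that MACA ships.
// In the real tree the kernel would include the cub_compat shim above, which
// provides CubSumOp. Each of the 256 threads contributes one element.
namespace infini_train::kernels::maca {

template <typename T>
__global__ void BlockSumKernel(const T *in, T *block_sums) {
    using BlockReduce = cub::BlockReduce<T, 256>;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    const T value = in[blockIdx.x * blockDim.x + threadIdx.x];
    const T sum = BlockReduce(temp_storage).Reduce(value, CubSumOp{});
    if (threadIdx.x == 0) {
        block_sums[blockIdx.x] = sum;
    }
}

} // namespace infini_train::kernels::maca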
