From 38e033001257515910e0762b4f98463b9e1f55ec Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 10:52:12 +0800 Subject: [PATCH 1/7] =?UTF-8?q?feat(ascend):=20add=20Ascend=20framework=20?= =?UTF-8?q?layer=20=E2=80=94=20runtime,=20type=20mapping,=20build=20integr?= =?UTF-8?q?ation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Ascend platform scaffolding: - `device_.h`: `DeviceEnabled` specialization - `data_type_.h`: `toAclDtype()`, `isIntegerDtype()` - `common.h`: `buildAclTensor()` with optional transpose - `workspace_pool_.h`: stream-keyed workspace allocator - `runtime_.h`: `Runtime` (Malloc, Free, Memcpy, Memset) - 5 new operator base classes (AddRmsNorm, FlashAttention, Matmul, ReshapeAndCache, RotaryEmbedding) Integrate into CMake build system, Python binding generation (stream + optional tensor support), and examples runtime API. --- .gitignore | 1 + CMakeLists.txt | 27 ++++++++- examples/runtime_api.h | 5 ++ scripts/generate_wrappers.py | 84 +++++++++++++++++++------- src/CMakeLists.txt | 52 +++++++++++++++- src/ascend/common.h | 58 ++++++++++++++++++ src/ascend/data_type_.h | 50 ++++++++++++++++ src/ascend/device_.h | 16 +++++ src/ascend/runtime_.h | 39 ++++++++++++ src/ascend/workspace_pool_.h | 53 +++++++++++++++++ src/base/add_rms_norm.h | 51 ++++++++++++++++ src/base/flash_attention.h | 112 +++++++++++++++++++++++++++++++++++ src/base/matmul.h | 41 +++++++++++++ src/base/reshape_and_cache.h | 73 +++++++++++++++++++++++ src/base/rotary_embedding.h | 80 +++++++++++++++++++++++++ src/operator.h | 9 +-- src/pybind11_utils.h | 6 ++ 17 files changed, 726 insertions(+), 31 deletions(-) create mode 100644 src/ascend/common.h create mode 100644 src/ascend/data_type_.h create mode 100644 src/ascend/device_.h create mode 100644 src/ascend/runtime_.h create mode 100644 src/ascend/workspace_pool_.h create mode 100644 src/base/add_rms_norm.h create mode 100644 src/base/flash_attention.h create mode 100644 src/base/matmul.h create mode 100644 src/base/reshape_and_cache.h create mode 100644 src/base/rotary_embedding.h diff --git a/.gitignore b/.gitignore index 2effaff..3ca9c90 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Generated files build/ generated/ +.worktrees/ # Prerequisites *.d diff --git a/CMakeLists.txt b/CMakeLists.txt index b9e2deb..7f4c5cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,7 @@ option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) option(WITH_CAMBRICON "Enable Cambricon backend" OFF) option(WITH_MOORE "Enable Moore backend" OFF) +option(WITH_ASCEND "Enable Ascend backend" OFF) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) @@ -71,20 +72,25 @@ if(AUTO_DETECT_DEVICES) set(WITH_MOORE OFF) set(WITH_MOORE OFF CACHE BOOL "Enable Moore backend" FORCE) endif() + + if(DEFINED ENV{ASCEND_HOME_PATH} OR EXISTS "/dev/davinci0") + set(WITH_ASCEND ON) + message(STATUS "Auto-detected Ascend environment.") + endif() endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) # Only one CUDA-like GPU backend can be enabled at a time. 
set(_gpu_backend_count 0) -foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE) +foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE WITH_ASCEND) if(${_gpu_backend}) math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1") endif() endforeach() if(_gpu_backend_count GREATER 1) - message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.") + message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.") endif() if(WITH_NVIDIA) @@ -178,8 +184,23 @@ if(WITH_CAMBRICON) find_library(CAMBRICON_PAPI_LIB NAMES cnpapi HINTS "${NEUWARE_HOME}/lib64" REQUIRED) endif() +if(WITH_ASCEND) + add_compile_definitions(WITH_ASCEND=1) + if(NOT DEFINED ASCEND_HOME) + if(DEFINED ENV{ASCEND_HOME_PATH} AND NOT "$ENV{ASCEND_HOME_PATH}" STREQUAL "") + set(ASCEND_HOME "$ENV{ASCEND_HOME_PATH}" CACHE PATH "Ascend toolkit root") + else() + set(ASCEND_HOME "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Ascend toolkit root") + endif() + endif() + if(NOT EXISTS "${ASCEND_HOME}") + message(FATAL_ERROR "`WITH_ASCEND` is ON but `${ASCEND_HOME}` was not found. Set ASCEND_HOME_PATH.") + endif() + message(STATUS "Using Ascend from `${ASCEND_HOME}`.") +endif() + # If all other platforms are not enabled, CPU is enabled by default. -if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON) +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON AND NOT WITH_ASCEND) add_compile_definitions(WITH_CPU=1) endif() diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 4c7469f..8b63153 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -19,6 +19,9 @@ #elif WITH_MOORE #include "moore/gemm/mublas.h" #include "moore/runtime_.h" +#elif WITH_ASCEND +#include "ascend/gemm/kernel.h" +#include "ascend/runtime_.h" #elif WITH_CPU #include "cpu/gemm/gemm.h" #include "cpu/runtime_.h" @@ -38,6 +41,8 @@ using DefaultRuntimeUtils = Runtime; using DefaultRuntimeUtils = Runtime; #elif WITH_MOORE using DefaultRuntimeUtils = Runtime; +#elif WITH_ASCEND +using DefaultRuntimeUtils = Runtime; #elif WITH_CPU using DefaultRuntimeUtils = Runtime; #endif diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 5aa8896..fc8f1bf 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -91,26 +91,56 @@ def __init__(self, name, constructors, calls): self.calls = calls +def _find_optional_tensor_params(op_name): + """Return a set of parameter names declared as `std::optional` in + the base header. libclang resolves the type to ``int`` when the STL + headers are not fully available, so we fall back to a regex scan of the + source text. 
+ """ + import re + + source = (_BASE_DIR / f"{op_name}.h").read_text() + return set(re.findall(r"std::optional\s+(\w+)", source)) + + def _generate_pybind11(operator): + optional_tensor_params = _find_optional_tensor_params(operator.name) + + def _is_optional_tensor(arg): + if arg.spelling in optional_tensor_params: + return True + return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + def _generate_params(node): - return ( - ", ".join( - f"{arg.type.spelling} {arg.spelling}" - for arg in node.get_arguments() - if arg.spelling != "stream" - ) - .replace("const Tensor", "py::object") - .replace("Tensor", "py::object") - ) + parts = [] + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + if _is_optional_tensor(arg): + parts.append(f"std::optional {arg.spelling}") + else: + param = ( + arg.type.spelling + .replace("const Tensor", "py::object") + .replace("Tensor", "py::object") + ) + parts.append(f"{param} {arg.spelling}") + return ", ".join(parts) def _generate_arguments(node): - return ", ".join( - f"TensorFromPybind11Handle({arg.spelling})" - if "Tensor" in arg.type.spelling - else arg.spelling - for arg in node.get_arguments() - if arg.spelling != "stream" - ) + args = [] + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + if _is_optional_tensor(arg): + args.append( + f"OptionalTensorFromPybind11Handle({arg.spelling})" + ) + elif "Tensor" in arg.type.spelling: + args.append(f"TensorFromPybind11Handle({arg.spelling})") + else: + args.append(arg.spelling) + return ", ".join(args) op_name = operator.name @@ -134,18 +164,24 @@ def _generate_call(op_name, call, method=True): if not method: params = ( - f"{call_params}, std::size_t implementation_index" + f"{call_params}, std::size_t implementation_index, std::uintptr_t stream" if call_params - else "std::size_t implementation_index" + else "std::size_t implementation_index, std::uintptr_t stream" ) py_args = _generate_py_args(call) py_args_str = f"{py_args}, " if py_args else "" - return f""" m.def("{op_name}", []({params}) {{ - Config config; - config.set_implementation_index(implementation_index); - return Self::call({{}}, config, {call_args}); - }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);""" + return ( + f' m.def("{op_name}", []({params}) {{\n' + f" Config config;\n" + f" config.set_implementation_index(implementation_index);\n" + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" return Self::call(handle, config, {call_args});\n" + f" }}, {py_args_str}py::kw_only(), py::arg(\"implementation_index\") = 0, py::arg(\"stream\") = 0);" + ) return f""" .def("__call__", [](const Self& self, {call_params}) {{ return static_cast&>(self)({call_args}); @@ -169,6 +205,8 @@ def _generate_call(op_name, call, method=True): #include "base/{op_name}.h" #include "config.h" +#include "handle.h" +#include "operator.h" #include "pybind11_utils.h" namespace py = pybind11; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0b56341..17abb8c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -40,7 +40,7 @@ if(WITH_NVIDIA) target_sources(infiniops PRIVATE ${NVIDIA_SOURCES}) find_package(CUDAToolkit REQUIRED) - target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cuda_driver) + target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) list(APPEND DEVICE_LIST "nvidia") set_target_properties(infiniops PROPERTIES @@ -172,10 
+172,60 @@ if(WITH_CAMBRICON)
   list(APPEND DEVICE_LIST "cambricon")
 endif()
 
+if(WITH_ASCEND)
+  # ASCEND_HOME is set by the top-level CMakeLists.txt.
+  file(GLOB_RECURSE ASCEND_SOURCES CONFIGURE_DEPENDS
+    "ascend/*.cc"
+    "ascend/*.cpp"
+  )
+  # Exclude kernel_impl.cpp — AscendC device code, not compiled by the host C++ compiler.
+  list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*kernel_impl\\.cpp$")
+
+  target_compile_definitions(infiniops PUBLIC WITH_ASCEND=1)
+  target_sources(infiniops PRIVATE ${ASCEND_SOURCES})
+
+  # The driver HAL lives beside the toolkit, two levels above ASCEND_HOME.
+  get_filename_component(ASCEND_ROOT "${ASCEND_HOME}/../.." ABSOLUTE)
+
+  # Prefer the real driver HAL; fall back to the toolkit stub for build-only
+  # environments (e.g., Docker CI images without hardware drivers installed).
+  # CANN <= 8.0: stub at runtime/lib64/stub/; CANN >= 8.5: devlib at <arch>-linux/devlib/.
+  set(ASCEND_HAL_REAL "${ASCEND_ROOT}/driver/lib64/driver/libascend_hal.so")
+  set(ASCEND_HAL_STUB "${ASCEND_HOME}/runtime/lib64/stub/libascend_hal.so")
+  set(ASCEND_HAL_DEVLIB "${ASCEND_HOME}/${CMAKE_SYSTEM_PROCESSOR}-linux/devlib/libascend_hal.so")
+  if(EXISTS "${ASCEND_HAL_REAL}")
+    set(ASCEND_HAL_LIB "${ASCEND_HAL_REAL}")
+  elseif(EXISTS "${ASCEND_HAL_STUB}")
+    set(ASCEND_HAL_LIB "${ASCEND_HAL_STUB}")
+    message(STATUS "ascend_hal: driver not found, using stub for linking")
+  elseif(EXISTS "${ASCEND_HAL_DEVLIB}")
+    set(ASCEND_HAL_LIB "${ASCEND_HAL_DEVLIB}")
+    message(STATUS "ascend_hal: driver not found, using devlib for linking")
+  else()
+    message(FATAL_ERROR "libascend_hal.so not found (tried ${ASCEND_HAL_REAL}, ${ASCEND_HAL_STUB}, and ${ASCEND_HAL_DEVLIB})")
+  endif()
+
+  target_include_directories(infiniops PUBLIC
+    "${ASCEND_HOME}/include"
+    "${ASCEND_HOME}/include/aclnn"
+    "${ASCEND_HOME}/include/aclnnop")
+  target_link_libraries(infiniops PUBLIC
+    "${ASCEND_HOME}/lib64/libascendcl.so"
+    "${ASCEND_HOME}/lib64/libnnopbase.so"
+    "${ASCEND_HOME}/lib64/libopapi.so"
+    "${ASCEND_HAL_LIB}")
+
+  list(APPEND DEVICE_LIST "ascend")
+endif()
+
 target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
 
 if(GENERATE_PYTHON_BINDINGS)
   find_package(Python COMPONENTS Interpreter REQUIRED)
+  # Always regenerate bindings so the included kernel headers match the
+  # active device list. Stale generated files (e.g., committed for one
+  # platform) would omit specializations for other enabled backends,
+  # causing link-time or runtime failures.
   execute_process(
     COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST}
     WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
diff --git a/src/ascend/common.h b/src/ascend/common.h
new file mode 100644
index 0000000..3dbeeae
--- /dev/null
+++ b/src/ascend/common.h
@@ -0,0 +1,58 @@
+#ifndef INFINI_OPS_ASCEND_COMMON_H_
+#define INFINI_OPS_ASCEND_COMMON_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "acl/acl.h"
+#include "aclnn/acl_meta.h"
+#include "ascend/data_type_.h"
+#include "tensor.h"
+
+namespace infini::ops::ascend {
+
+// Build an aclTensor descriptor from an InfiniOps Tensor.
+//
+// When `transpose_last2` is true the last two dimensions are swapped in the
+// descriptor (shape and strides) without copying data. This is used by GEMM
+// and Matmul to express a transpose via the view.
+inline aclTensor* buildAclTensor(const Tensor& t,
+                                 bool transpose_last2 = false) {
+    std::vector<int64_t> shape(t.shape().begin(), t.shape().end());
+    std::vector<int64_t> strides(t.strides().begin(), t.strides().end());
+
+    if (transpose_last2 && shape.size() >= 2) {
+        auto n = shape.size();
+        std::swap(shape[n - 2], shape[n - 1]);
+        std::swap(strides[n - 2], strides[n - 1]);
+    }
+
+    // Compute the minimum physical storage needed for this strided view.
+    // For contiguous tensors this equals numel(); for non-contiguous (gapped)
+    // tensors it may be larger; for broadcast (stride-0) tensors it may be
+    // smaller. Passing the view shape as the storage shape causes
+    // "ViewShape overlap" errors in ACLNN for non-contiguous inputs.
+    int64_t storage_elems = 1;
+    for (size_t i = 0; i < shape.size(); ++i) {
+        if (shape[i] == 0) { storage_elems = 0; break; }
+        if (strides[i] > 0 && shape[i] > 1) {
+            storage_elems += static_cast<int64_t>(shape[i] - 1) * strides[i];
+        }
+    }
+    std::vector<int64_t> storage_shape = {storage_elems};
+
+    return aclCreateTensor(
+        shape.data(),
+        static_cast<uint64_t>(shape.size()),
+        toAclDtype(t.dtype()),
+        strides.data(),
+        /*storageOffset=*/0,
+        ACL_FORMAT_ND,
+        storage_shape.data(),
+        static_cast<uint64_t>(storage_shape.size()),
+        const_cast<void*>(t.data()));
+}
+
+}  // namespace infini::ops::ascend
+
+#endif
diff --git a/src/ascend/data_type_.h b/src/ascend/data_type_.h
new file mode 100644
index 0000000..0574b23
--- /dev/null
+++ b/src/ascend/data_type_.h
@@ -0,0 +1,50 @@
+#ifndef INFINI_OPS_ASCEND_DATA_TYPE__H_
+#define INFINI_OPS_ASCEND_DATA_TYPE__H_
+
+#include <cassert>
+
+#include "acl/acl.h"
+#include "ascend/device_.h"
+#include "data_type.h"
+
+namespace infini::ops::ascend {
+
+inline aclDataType toAclDtype(DataType dt) {
+    switch (dt) {
+        case DataType::kFloat16: return ACL_FLOAT16;
+        case DataType::kBFloat16: return ACL_BF16;
+        case DataType::kFloat32: return ACL_FLOAT;
+        case DataType::kInt8: return ACL_INT8;
+        case DataType::kInt16: return ACL_INT16;
+        case DataType::kInt32: return ACL_INT32;
+        case DataType::kInt64: return ACL_INT64;
+        case DataType::kUInt8: return ACL_UINT8;
+        case DataType::kUInt16: return ACL_UINT16;
+        case DataType::kUInt32: return ACL_UINT32;
+        case DataType::kUInt64: return ACL_UINT64;
+        default:
+            assert(false && "unsupported dtype for Ascend backend");
+            return ACL_DT_UNDEFINED;
+    }
+}
+
+// Returns true for integer (signed or unsigned) DataType values.
+inline bool isIntegerDtype(DataType dt) {
+    switch (dt) {
+        case DataType::kInt8:
+        case DataType::kInt16:
+        case DataType::kInt32:
+        case DataType::kInt64:
+        case DataType::kUInt8:
+        case DataType::kUInt16:
+        case DataType::kUInt32:
+        case DataType::kUInt64:
+            return true;
+        default:
+            return false;
+    }
+}
+
+}  // namespace infini::ops::ascend
+
+#endif
diff --git a/src/ascend/device_.h b/src/ascend/device_.h
new file mode 100644
index 0000000..b4ec934
--- /dev/null
+++ b/src/ascend/device_.h
@@ -0,0 +1,16 @@
+#ifndef INFINI_OPS_ASCEND_DEVICE__H_
+#define INFINI_OPS_ASCEND_DEVICE__H_
+
+// NOTE: Cannot use `#include "device.h"` here — GCC resolves quoted includes
+// relative to the current file first, and `src/ascend/` used to contain a
+// `device.h`. Use `data_type.h` which transitively pulls in `src/device.h`.
+#include "data_type.h" + +namespace infini::ops { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/runtime_.h b/src/ascend/runtime_.h new file mode 100644 index 0000000..2918d5e --- /dev/null +++ b/src/ascend/runtime_.h @@ -0,0 +1,39 @@ +#ifndef INFINI_OPS_ASCEND_RUNTIME__H_ +#define INFINI_OPS_ASCEND_RUNTIME__H_ + +// clang-format off +#include "acl/acl.h" +// clang-format on + +#include "ascend/device_.h" +#include "runtime.h" + +namespace infini::ops { + +template <> +struct Runtime + : DeviceRuntime> { + using Stream = aclrtStream; + + static constexpr Device::Type kDeviceType = Device::Type::kAscend; + + static constexpr auto Malloc = [](void** ptr, size_t size) { + return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); + }; + + static constexpr auto Free = aclrtFree; + + static constexpr auto Memcpy = aclrtMemcpy; + + static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; + + static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; + + static constexpr auto Memset = aclrtMemset; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::ops + +#endif diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h new file mode 100644 index 0000000..bd305fe --- /dev/null +++ b/src/ascend/workspace_pool_.h @@ -0,0 +1,53 @@ +#ifndef INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ +#define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ + +#include +#include +#include + +#include "acl/acl.h" + +namespace infini::ops::ascend { + +struct WorkspaceArena { + void* buf = nullptr; + uint64_t capacity = 0; +}; + +class WorkspacePool { + public: + WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) { + std::lock_guard lock(mutex_); + auto& arena = arenas_[stream]; + if (needed <= arena.capacity) return arena; + if (arena.capacity > 0) { + aclrtSynchronizeStream(stream); + aclrtFree(arena.buf); + } + if (needed > 0) { + aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + } + arena.capacity = needed; + return arena; + } + + ~WorkspacePool() { + for (auto& [stream, arena] : arenas_) { + if (arena.capacity > 0) aclrtFree(arena.buf); + } + } + + private: + std::unordered_map arenas_; + + std::mutex mutex_; +}; + +inline WorkspacePool& workspacePool() { + static WorkspacePool pool; + return pool; +} + +} // namespace infini::ops::ascend + +#endif diff --git a/src/base/add_rms_norm.h b/src/base/add_rms_norm.h new file mode 100644 index 0000000..b8315af --- /dev/null +++ b/src/base/add_rms_norm.h @@ -0,0 +1,51 @@ +#ifndef INFINI_OPS_BASE_ADD_RMS_NORM_H_ +#define INFINI_OPS_BASE_ADD_RMS_NORM_H_ + +#include +#include + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class AddRmsNorm : public Operator { + public: + AddRmsNorm(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : input_shape_{x1.shape()}, + eps_{eps}, + dim_{x1.size(-1)}, + ndim_{x1.ndim()}, + batch_size_{ndim_ == 2 ? x1.size(-2) : x1.size(-3)}, + nhead_{ndim_ == 2 ? 
+        rstd_shape_{static_cast<int64_t>(batch_size_),
+                    static_cast<int64_t>(nhead_)} {
+    assert(x1.dtype() == x2.dtype());
+    assert(x1.dtype() == y_out.dtype());
+    assert(x1.dtype() == x_out.dtype());
+  }
+
+  virtual void operator()(const Tensor x1, const Tensor x2,
+                          const Tensor gamma, float eps, Tensor y_out,
+                          Tensor x_out) const = 0;
+
+ protected:
+  Tensor::Shape input_shape_;
+
+  float eps_{1e-6f};
+
+  Tensor::Size dim_{0};
+
+  Tensor::Size ndim_{0};
+
+  Tensor::Size batch_size_{0};
+
+  Tensor::Size nhead_{1};
+
+  std::vector<int64_t> rstd_shape_;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/base/flash_attention.h b/src/base/flash_attention.h
new file mode 100644
index 0000000..7820e55
--- /dev/null
+++ b/src/base/flash_attention.h
@@ -0,0 +1,112 @@
+#ifndef INFINI_OPS_BASE_FLASH_ATTENTION_H_
+#define INFINI_OPS_BASE_FLASH_ATTENTION_H_
+
+#include <cassert>
+#include <cstdint>
+#include <optional>
+
+#include "operator.h"
+
+namespace infini::ops {
+
+class FlashAttention : public Operator {
+ public:
+  FlashAttention(
+      const Tensor query, const Tensor key, const Tensor value,
+      std::optional<Tensor> cu_seqlens_q,
+      std::optional<Tensor> cu_seqlens_kv,
+      std::optional<Tensor> block_table,
+      int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
+      double scale,
+      bool causal,
+      int64_t window_left,
+      int64_t window_right,
+      int64_t block_size,
+      Tensor output)
+      : num_tokens_{query.size(0)},
+        num_heads_{num_heads},
+        num_kv_heads_{num_kv_heads},
+        head_size_{head_size},
+        scale_{scale},
+        causal_{causal},
+        window_left_{window_left},
+        window_right_{window_right},
+        block_size_{block_size},
+        dtype_{query.dtype()},
+        query_shape_{query.shape()},
+        key_shape_{key.shape()},
+        value_shape_{value.shape()},
+        output_shape_{output.shape()},
+        query_strides_{query.strides()},
+        key_strides_{key.strides()},
+        value_strides_{value.strides()},
+        output_strides_{output.strides()},
+        has_cu_seqlens_q_{cu_seqlens_q.has_value()},
+        has_cu_seqlens_kv_{cu_seqlens_kv.has_value()},
+        has_block_table_{block_table.has_value()} {
+    assert(num_heads % num_kv_heads == 0 &&
+           "`FlashAttention` requires num_heads divisible by num_kv_heads");
+    assert(query.ndim() == 3 &&
+           "`FlashAttention` requires query to be 3D [T, N, D]");
+  }
+
+  virtual void operator()(
+      const Tensor query, const Tensor key, const Tensor value,
+      std::optional<Tensor> cu_seqlens_q,
+      std::optional<Tensor> cu_seqlens_kv,
+      std::optional<Tensor> block_table,
+      int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
+      double scale,
+      bool causal,
+      int64_t window_left,
+      int64_t window_right,
+      int64_t block_size,
+      Tensor output) const = 0;
+
+ protected:
+  Tensor::Size num_tokens_{0};
+
+  int64_t num_heads_{0};
+
+  int64_t num_kv_heads_{0};
+
+  int64_t head_size_{0};
+
+  double scale_{0.0};
+
+  bool causal_{false};
+
+  int64_t window_left_{-1};
+
+  int64_t window_right_{-1};
+
+  int64_t block_size_{0};
+
+  const DataType dtype_;
+
+  Tensor::Shape query_shape_;
+
+  Tensor::Shape key_shape_;
+
+  Tensor::Shape value_shape_;
+
+  Tensor::Shape output_shape_;
+
+  Tensor::Strides query_strides_;
+
+  Tensor::Strides key_strides_;
+
+  Tensor::Strides value_strides_;
+
+  Tensor::Strides output_strides_;
+
+  bool has_cu_seqlens_q_{false};
+
+  bool has_cu_seqlens_kv_{false};
+
+  bool has_block_table_{false};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/base/matmul.h b/src/base/matmul.h
new file mode 100644
index 0000000..e988aa1
--- /dev/null
+++ b/src/base/matmul.h
@@ -0,0 +1,41 @@
+#ifndef INFINI_OPS_BASE_MATMUL_H_
+#define INFINI_OPS_BASE_MATMUL_H_
+
+#include "operator.h"
+#include 
"tensor.h" + +namespace infini::ops { + +class Matmul : public Operator { + public: + // trans_a / trans_b: if true, transpose the last two dims of a / b before + // multiplying. These are constructor parameters so the CacheKey encodes + // the transposition and distinct descriptors are cached for each combination. + Matmul(const Tensor a, const Tensor b, Tensor c, + bool trans_a, bool trans_b) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + c_shape_{c.shape()}, + trans_a_{trans_a}, + trans_b_{trans_b} { + assert(a.dtype() == b.dtype()); + } + + virtual void operator()(const Tensor a, const Tensor b, Tensor c, + bool trans_a, bool trans_b) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape c_shape_; + + bool trans_a_{false}; + + bool trans_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/reshape_and_cache.h b/src/base/reshape_and_cache.h new file mode 100644 index 0000000..a53caca --- /dev/null +++ b/src/base/reshape_and_cache.h @@ -0,0 +1,73 @@ +#ifndef INFINI_OPS_BASE_RESHAPE_AND_CACHE_H_ +#define INFINI_OPS_BASE_RESHAPE_AND_CACHE_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class ReshapeAndCache : public Operator { + public: + ReshapeAndCache( + const Tensor key, const Tensor value, + const Tensor kv_cache, const Tensor slot_mapping, + Tensor kv_cache_out) + : num_tokens_{key.size(0)}, + num_kv_heads_{key.size(1)}, + head_size_{key.size(2)}, + block_size_{kv_cache.size(2)}, + key_shape_{key.shape()}, + value_shape_{value.shape()}, + kv_cache_shape_{kv_cache.shape()}, + slot_mapping_shape_{slot_mapping.shape()}, + key_strides_{key.strides()}, + value_strides_{value.strides()}, + kv_cache_strides_{kv_cache.strides()}, + slot_mapping_strides_{slot_mapping.strides()}, + kv_cache_out_strides_{kv_cache_out.strides()} { + assert(key.shape() == value.shape() && + "`ReshapeAndCache` requires key and value same shape"); + assert(kv_cache.ndim() == 5 && + "`ReshapeAndCache` requires kv_cache to be 5D [2, num_blocks, block_size, num_kv_heads, head_size]"); + assert(slot_mapping.ndim() == 1 && + "`ReshapeAndCache` requires slot_mapping to be 1D"); + } + + virtual void operator()( + const Tensor key, const Tensor value, + const Tensor kv_cache, const Tensor slot_mapping, + Tensor kv_cache_out) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + Tensor::Size num_kv_heads_{0}; + + Tensor::Size head_size_{0}; + + Tensor::Size block_size_{0}; + + Tensor::Shape key_shape_; + + Tensor::Shape value_shape_; + + Tensor::Shape kv_cache_shape_; + + Tensor::Shape slot_mapping_shape_; + + Tensor::Strides key_strides_; + + Tensor::Strides value_strides_; + + Tensor::Strides kv_cache_strides_; + + Tensor::Strides slot_mapping_strides_; + + Tensor::Strides kv_cache_out_strides_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/rotary_embedding.h b/src/base/rotary_embedding.h new file mode 100644 index 0000000..a38b20e --- /dev/null +++ b/src/base/rotary_embedding.h @@ -0,0 +1,80 @@ +#ifndef INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ +#define INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class RotaryEmbedding : public Operator { + public: + RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) + : num_tokens_{query.size(0)}, + num_heads_{query.size(1)}, + 
num_kv_heads_{key.size(1)}, + head_size_{head_size}, + rotary_dim_{rotary_dim}, + is_neox_style_{is_neox_style}, + query_shape_{query.shape()}, + key_shape_{key.shape()}, + cos_sin_cache_shape_{cos_sin_cache.shape()}, + query_out_shape_{query_out.shape()}, + key_out_shape_{key_out.shape()}, + query_strides_{query.strides()}, + key_strides_{key.strides()}, + query_out_strides_{query_out.strides()}, + key_out_strides_{key_out.strides()} { + assert(query.ndim() == 3 && + "`RotaryEmbedding` requires query to be 3D [T, N, D]"); + assert(key.ndim() == 3 && + "`RotaryEmbedding` requires key to be 3D [T, N_kv, D]"); + assert(rotary_dim <= head_size && + "`RotaryEmbedding` requires rotary_dim <= head_size"); + } + + virtual void operator()(const Tensor positions, const Tensor query, + const Tensor key, const Tensor cos_sin_cache, + int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, + Tensor key_out) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + int64_t rotary_dim_{0}; + + bool is_neox_style_{true}; + + Tensor::Shape query_shape_; + + Tensor::Shape key_shape_; + + Tensor::Shape cos_sin_cache_shape_; + + Tensor::Shape query_out_shape_; + + Tensor::Shape key_out_shape_; + + Tensor::Strides query_strides_; + + Tensor::Strides key_strides_; + + Tensor::Strides query_out_strides_; + + Tensor::Strides key_out_strides_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/operator.h b/src/operator.h index 76efd7a..72e8337 100644 --- a/src/operator.h +++ b/src/operator.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -176,10 +177,10 @@ class Operator : public OperatorBase { auto it{cache.find(key)}; if (it == cache.end()) { - it = cache - .emplace(std::move(key), - make(config, std::forward(args)...)) - .first; + // Pass args as lvalue refs so they remain valid for the operator() call + // below. Forwarding rvalue temporaries into make() would leave the args + // in a moved-from (empty) state before operator() can use them. + it = cache.emplace(std::move(key), make(config, args...)).first; } auto& op{it->second}; diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 0f5e73b..766b6ea 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -116,6 +116,12 @@ inline Tensor TensorFromPybind11Handle(py::handle obj) { return Tensor{data, std::move(shape), dtype, device, std::move(strides)}; } +inline std::optional OptionalTensorFromPybind11Handle( + const std::optional& obj) { + if (!obj.has_value()) return std::nullopt; + return TensorFromPybind11Handle(*obj); +} + } // namespace infini::ops #endif From 08e0d6aaeb008a3203f20eeae196019dbb6e32de Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 11:08:51 +0800 Subject: [PATCH 2/7] style(ascend): apply `clang-format` to framework headers --- src/ascend/common.h | 60 +++++++++++++++---------------- src/ascend/data_type_.h | 69 +++++++++++++++++++++--------------- src/ascend/workspace_pool_.h | 46 ++++++++++++------------ src/base/add_rms_norm.h | 5 ++- src/base/flash_attention.h | 40 +++++++++------------ src/base/matmul.h | 3 +- src/base/reshape_and_cache.h | 16 ++++----- 7 files changed, 118 insertions(+), 121 deletions(-) diff --git a/src/ascend/common.h b/src/ascend/common.h index 3dbeeae..f5ecb1a 100644 --- a/src/ascend/common.h +++ b/src/ascend/common.h @@ -18,39 +18,37 @@ namespace infini::ops::ascend { // and Matmul to express a transpose via the view. 
inline aclTensor* buildAclTensor(const Tensor& t, bool transpose_last2 = false) { - std::vector shape(t.shape().begin(), t.shape().end()); - std::vector strides(t.strides().begin(), t.strides().end()); - - if (transpose_last2 && shape.size() >= 2) { - auto n = shape.size(); - std::swap(shape[n - 2], shape[n - 1]); - std::swap(strides[n - 2], strides[n - 1]); + std::vector shape(t.shape().begin(), t.shape().end()); + std::vector strides(t.strides().begin(), t.strides().end()); + + if (transpose_last2 && shape.size() >= 2) { + auto n = shape.size(); + std::swap(shape[n - 2], shape[n - 1]); + std::swap(strides[n - 2], strides[n - 1]); + } + + // Compute the minimum physical storage needed for this strided view. + // For contiguous tensors this equals numel(); for non-contiguous (gapped) + // tensors it may be larger; for broadcast (stride-0) tensors it may be + // smaller. Passing the view shape as the storage shape causes + // "ViewShape overlap" errors in ACLNN for non-contiguous inputs. + int64_t storage_elems = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == 0) { + storage_elems = 0; + break; } - - // Compute the minimum physical storage needed for this strided view. - // For contiguous tensors this equals numel(); for non-contiguous (gapped) - // tensors it may be larger; for broadcast (stride-0) tensors it may be - // smaller. Passing the view shape as the storage shape causes - // "ViewShape overlap" errors in ACLNN for non-contiguous inputs. - int64_t storage_elems = 1; - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == 0) { storage_elems = 0; break; } - if (strides[i] > 0 && shape[i] > 1) { - storage_elems += static_cast(shape[i] - 1) * strides[i]; - } + if (strides[i] > 0 && shape[i] > 1) { + storage_elems += static_cast(shape[i] - 1) * strides[i]; } - std::vector storage_shape = {storage_elems}; - - return aclCreateTensor( - shape.data(), - static_cast(shape.size()), - toAclDtype(t.dtype()), - strides.data(), - /*storageOffset=*/0, - ACL_FORMAT_ND, - storage_shape.data(), - static_cast(storage_shape.size()), - const_cast(t.data())); + } + std::vector storage_shape = {storage_elems}; + + return aclCreateTensor( + shape.data(), static_cast(shape.size()), toAclDtype(t.dtype()), + strides.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape.data(), + static_cast(storage_shape.size()), const_cast(t.data())); } } // namespace infini::ops::ascend diff --git a/src/ascend/data_type_.h b/src/ascend/data_type_.h index 0574b23..08b1541 100644 --- a/src/ascend/data_type_.h +++ b/src/ascend/data_type_.h @@ -10,39 +10,50 @@ namespace infini::ops::ascend { inline aclDataType toAclDtype(DataType dt) { - switch (dt) { - case DataType::kFloat16: return ACL_FLOAT16; - case DataType::kBFloat16: return ACL_BF16; - case DataType::kFloat32: return ACL_FLOAT; - case DataType::kInt8: return ACL_INT8; - case DataType::kInt16: return ACL_INT16; - case DataType::kInt32: return ACL_INT32; - case DataType::kInt64: return ACL_INT64; - case DataType::kUInt8: return ACL_UINT8; - case DataType::kUInt16: return ACL_UINT16; - case DataType::kUInt32: return ACL_UINT32; - case DataType::kUInt64: return ACL_UINT64; - default: - assert(false && "unsupported dtype for Ascend backend"); - return ACL_DT_UNDEFINED; - } + switch (dt) { + case DataType::kFloat16: + return ACL_FLOAT16; + case DataType::kBFloat16: + return ACL_BF16; + case DataType::kFloat32: + return ACL_FLOAT; + case DataType::kInt8: + return ACL_INT8; + case DataType::kInt16: + return ACL_INT16; + case DataType::kInt32: + 
+      return ACL_INT32;
+    case DataType::kInt64:
+      return ACL_INT64;
+    case DataType::kUInt8:
+      return ACL_UINT8;
+    case DataType::kUInt16:
+      return ACL_UINT16;
+    case DataType::kUInt32:
+      return ACL_UINT32;
+    case DataType::kUInt64:
+      return ACL_UINT64;
+    default:
+      assert(false && "unsupported dtype for Ascend backend");
+      return ACL_DT_UNDEFINED;
+  }
 }
 
 // Returns true for integer (signed or unsigned) DataType values.
 inline bool isIntegerDtype(DataType dt) {
-    switch (dt) {
-        case DataType::kInt8:
-        case DataType::kInt16:
-        case DataType::kInt32:
-        case DataType::kInt64:
-        case DataType::kUInt8:
-        case DataType::kUInt16:
-        case DataType::kUInt32:
-        case DataType::kUInt64:
-            return true;
-        default:
-            return false;
-    }
+  switch (dt) {
+    case DataType::kInt8:
+    case DataType::kInt16:
+    case DataType::kInt32:
+    case DataType::kInt64:
+    case DataType::kUInt8:
+    case DataType::kUInt16:
+    case DataType::kUInt32:
+    case DataType::kUInt64:
+      return true;
+    default:
+      return false;
+  }
 }
 
 }  // namespace infini::ops::ascend
diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h
index bd305fe..a44070e 100644
--- a/src/ascend/workspace_pool_.h
+++ b/src/ascend/workspace_pool_.h
@@ -10,42 +10,42 @@
 namespace infini::ops::ascend {
 
 struct WorkspaceArena {
-    void* buf = nullptr;
-    uint64_t capacity = 0;
+  void* buf = nullptr;
+  uint64_t capacity = 0;
 };
 
 class WorkspacePool {
  public:
-    WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) {
-        std::lock_guard<std::mutex> lock(mutex_);
-        auto& arena = arenas_[stream];
-        if (needed <= arena.capacity) return arena;
-        if (arena.capacity > 0) {
-            aclrtSynchronizeStream(stream);
-            aclrtFree(arena.buf);
-        }
-        if (needed > 0) {
-            aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY);
-        }
-        arena.capacity = needed;
-        return arena;
+  WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto& arena = arenas_[stream];
+    if (needed <= arena.capacity) return arena;
+    if (arena.capacity > 0) {
+      aclrtSynchronizeStream(stream);
+      aclrtFree(arena.buf);
     }
+    if (needed > 0) {
+      aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY);
+    }
+    arena.capacity = needed;
+    return arena;
+  }
 
-    ~WorkspacePool() {
-        for (auto& [stream, arena] : arenas_) {
-            if (arena.capacity > 0) aclrtFree(arena.buf);
-        }
+  ~WorkspacePool() {
+    for (auto& [stream, arena] : arenas_) {
+      if (arena.capacity > 0) aclrtFree(arena.buf);
     }
+  }
 
  private:
-    std::unordered_map<aclrtStream, WorkspaceArena> arenas_;
+  std::unordered_map<aclrtStream, WorkspaceArena> arenas_;
 
-    std::mutex mutex_;
+  std::mutex mutex_;
 };
 
 inline WorkspacePool& workspacePool() {
-    static WorkspacePool pool;
-    return pool;
+  static WorkspacePool pool;
+  return pool;
 }
 
 }  // namespace infini::ops::ascend
diff --git a/src/base/add_rms_norm.h b/src/base/add_rms_norm.h
index b8315af..8243a53 100644
--- a/src/base/add_rms_norm.h
+++ b/src/base/add_rms_norm.h
@@ -26,9 +26,8 @@ class AddRmsNorm : public Operator {
     assert(x1.dtype() == x_out.dtype());
   }
 
-  virtual void operator()(const Tensor x1, const Tensor x2,
-                          const Tensor gamma, float eps, Tensor y_out,
-                          Tensor x_out) const = 0;
+  virtual void operator()(const Tensor x1, const Tensor x2, const Tensor gamma,
+                          float eps, Tensor y_out, Tensor x_out) const = 0;
 
  protected:
   Tensor::Shape input_shape_;
diff --git a/src/base/flash_attention.h b/src/base/flash_attention.h
index 7820e55..734e9a2 100644
--- a/src/base/flash_attention.h
+++ b/src/base/flash_attention.h
@@ -11,18 +11,13 @@ namespace infini::ops {
 
 class FlashAttention : public Operator {
  public:
-  FlashAttention(
-      const Tensor query, const Tensor key, const Tensor value,
-      std::optional<Tensor> cu_seqlens_q,
-      std::optional<Tensor> cu_seqlens_kv,
-      std::optional<Tensor> block_table,
-      int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
-      double scale,
-      bool causal,
-      int64_t window_left,
-      int64_t window_right,
-      int64_t block_size,
-      Tensor output)
+  FlashAttention(const Tensor query, const Tensor key, const Tensor value,
+                 std::optional<Tensor> cu_seqlens_q,
+                 std::optional<Tensor> cu_seqlens_kv,
+                 std::optional<Tensor> block_table, int64_t num_heads,
+                 int64_t num_kv_heads, int64_t head_size, double scale,
+                 bool causal, int64_t window_left, int64_t window_right,
+                 int64_t block_size, Tensor output)
       : num_tokens_{query.size(0)},
         num_heads_{num_heads},
         num_kv_heads_{num_kv_heads},
@@ -50,18 +45,15 @@ class FlashAttention : public Operator {
            "`FlashAttention` requires query to be 3D [T, N, D]");
   }
 
-  virtual void operator()(
-      const Tensor query, const Tensor key, const Tensor value,
-      std::optional<Tensor> cu_seqlens_q,
-      std::optional<Tensor> cu_seqlens_kv,
-      std::optional<Tensor> block_table,
-      int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
-      double scale,
-      bool causal,
-      int64_t window_left,
-      int64_t window_right,
-      int64_t block_size,
-      Tensor output) const = 0;
+  virtual void operator()(const Tensor query, const Tensor key,
+                          const Tensor value,
+                          std::optional<Tensor> cu_seqlens_q,
+                          std::optional<Tensor> cu_seqlens_kv,
+                          std::optional<Tensor> block_table, int64_t num_heads,
+                          int64_t num_kv_heads, int64_t head_size, double scale,
+                          bool causal, int64_t window_left,
+                          int64_t window_right, int64_t block_size,
+                          Tensor output) const = 0;
 
  protected:
   Tensor::Size num_tokens_{0};
diff --git a/src/base/matmul.h b/src/base/matmul.h
index e988aa1..48812c4 100644
--- a/src/base/matmul.h
+++ b/src/base/matmul.h
@@ -11,8 +11,7 @@ class Matmul : public Operator {
   // trans_a / trans_b: if true, transpose the last two dims of a / b before
   // multiplying. These are constructor parameters so the CacheKey encodes
   // the transposition and distinct descriptors are cached for each combination.
- Matmul(const Tensor a, const Tensor b, Tensor c, - bool trans_a, bool trans_b) + Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) : a_shape_{a.shape()}, b_shape_{b.shape()}, c_shape_{c.shape()}, diff --git a/src/base/reshape_and_cache.h b/src/base/reshape_and_cache.h index a53caca..5d0adfa 100644 --- a/src/base/reshape_and_cache.h +++ b/src/base/reshape_and_cache.h @@ -10,10 +10,8 @@ namespace infini::ops { class ReshapeAndCache : public Operator { public: - ReshapeAndCache( - const Tensor key, const Tensor value, - const Tensor kv_cache, const Tensor slot_mapping, - Tensor kv_cache_out) + ReshapeAndCache(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) : num_tokens_{key.size(0)}, num_kv_heads_{key.size(1)}, head_size_{key.size(2)}, @@ -30,15 +28,15 @@ class ReshapeAndCache : public Operator { assert(key.shape() == value.shape() && "`ReshapeAndCache` requires key and value same shape"); assert(kv_cache.ndim() == 5 && - "`ReshapeAndCache` requires kv_cache to be 5D [2, num_blocks, block_size, num_kv_heads, head_size]"); + "`ReshapeAndCache` requires kv_cache to be 5D [2, num_blocks, " + "block_size, num_kv_heads, head_size]"); assert(slot_mapping.ndim() == 1 && "`ReshapeAndCache` requires slot_mapping to be 1D"); } - virtual void operator()( - const Tensor key, const Tensor value, - const Tensor kv_cache, const Tensor slot_mapping, - Tensor kv_cache_out) const = 0; + virtual void operator()(const Tensor key, const Tensor value, + const Tensor kv_cache, const Tensor slot_mapping, + Tensor kv_cache_out) const = 0; protected: Tensor::Size num_tokens_{0}; From 88a437935a8fd7a7fa4672a81af641b5ffa60bce Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 12:13:39 +0800 Subject: [PATCH 3/7] fix(ascend): adapt `Memcpy`/`Memset` arity, assert workspace alloc, remove missing include - Wrap `aclrtMemcpy` (5-arg) and `aclrtMemset` (4-arg) in lambdas to match the generic 4-arg / 3-arg calling convention used by examples. - Assert `aclrtMalloc` return value in `WorkspacePool::ensure()`. - Remove `ascend/gemm/kernel.h` include from `runtime_api.h` (file does not exist until the kernels commit). 
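
For reference, the arity mismatch being adapted (signatures paraphrased from
the CUDA and ACL runtime headers; `destMax` is the extra capacity parameter
the generic convention has no slot for):

    cudaMemcpy(dst, src, count, kind);            // 4 args: the generic shape
    aclrtMemcpy(dst, destMax, src, count, kind);  // 5 args
    cudaMemset(ptr, value, count);                // 3 args: the generic shape
    aclrtMemset(ptr, destMax, value, count);      // 4 args

Passing `count` for `destMax` assumes the destination allocation holds at
least `count` bytes, which holds for the examples' Malloc/Memcpy pairing.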
--- examples/runtime_api.h | 5 ----- src/ascend/runtime_.h | 9 +++++++-- src/ascend/workspace_pool_.h | 5 ++++- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 8b63153..4c7469f 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -19,9 +19,6 @@ #elif WITH_MOORE #include "moore/gemm/mublas.h" #include "moore/runtime_.h" -#elif WITH_ASCEND -#include "ascend/gemm/kernel.h" -#include "ascend/runtime_.h" #elif WITH_CPU #include "cpu/gemm/gemm.h" #include "cpu/runtime_.h" @@ -41,8 +38,6 @@ using DefaultRuntimeUtils = Runtime; using DefaultRuntimeUtils = Runtime; #elif WITH_MOORE using DefaultRuntimeUtils = Runtime; -#elif WITH_ASCEND -using DefaultRuntimeUtils = Runtime; #elif WITH_CPU using DefaultRuntimeUtils = Runtime; #endif diff --git a/src/ascend/runtime_.h b/src/ascend/runtime_.h index 2918d5e..dca7425 100644 --- a/src/ascend/runtime_.h +++ b/src/ascend/runtime_.h @@ -23,13 +23,18 @@ struct Runtime static constexpr auto Free = aclrtFree; - static constexpr auto Memcpy = aclrtMemcpy; + static constexpr auto Memcpy = [](void* dst, const void* src, size_t count, + aclrtMemcpyKind kind) { + return aclrtMemcpy(dst, count, src, count, kind); + }; static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; - static constexpr auto Memset = aclrtMemset; + static constexpr auto Memset = [](void* ptr, int value, size_t count) { + return aclrtMemset(ptr, count, value, count); + }; }; static_assert(Runtime::Validate()); diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h index a44070e..d97a20e 100644 --- a/src/ascend/workspace_pool_.h +++ b/src/ascend/workspace_pool_.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ #define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ +#include #include #include #include @@ -25,7 +26,9 @@ class WorkspacePool { aclrtFree(arena.buf); } if (needed > 0) { - aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY) == + ACL_SUCCESS && + "`WorkspacePool`: `aclrtMalloc` failed"); } arena.capacity = needed; return arena; From 4833eb94ffc4cd657cf3e78c6f68decee6f620ca Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 12:25:42 +0800 Subject: [PATCH 4/7] feat(ascend): add GEMM kernel, NPU test infra, and example integration - Add Ascend GEMM specialization using `aclnnAddmm`/`aclnnBaddbmm`. - Add `get_npu_stream()` helper and NPU device detection in test utils. - Add `skip_unsupported_dtype` fixture for Ascend in conftest. - Update `runtime_api.h` with Ascend backend entry. 
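
For orientation, every aclnn entry point follows the same two-phase
convention that the GEMM kernel relies on. A sketch with the addmm pair
(argument names paraphrased from the aclnn headers):

    uint64_t ws_bytes = 0;
    aclOpExecutor* executor = nullptr;
    // Phase 1: plan the op on the host and query its workspace requirement.
    aclnnAddmmGetWorkspaceSize(self, mat1, mat2, beta, alpha, out,
                               /*cubeMathType=*/0, &ws_bytes, &executor);
    // Phase 2: enqueue on the stream with a workspace of >= ws_bytes bytes.
    aclnnAddmm(workspace, ws_bytes, executor, stream);

`WorkspacePool` exists to amortize the phase-2 allocation across repeated
calls on the same stream.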
--- examples/runtime_api.h | 5 +++ src/ascend/gemm/kernel.h | 80 ++++++++++++++++++++++++++++++++++++++++ tests/conftest.py | 19 ++++++++++ tests/test_gemm.py | 28 ++++++++------ tests/utils.py | 14 +++++++ 5 files changed, 135 insertions(+), 11 deletions(-) create mode 100644 src/ascend/gemm/kernel.h diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 4c7469f..8b63153 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -19,6 +19,9 @@ #elif WITH_MOORE #include "moore/gemm/mublas.h" #include "moore/runtime_.h" +#elif WITH_ASCEND +#include "ascend/gemm/kernel.h" +#include "ascend/runtime_.h" #elif WITH_CPU #include "cpu/gemm/gemm.h" #include "cpu/runtime_.h" @@ -38,6 +41,8 @@ using DefaultRuntimeUtils = Runtime; using DefaultRuntimeUtils = Runtime; #elif WITH_MOORE using DefaultRuntimeUtils = Runtime; +#elif WITH_ASCEND +using DefaultRuntimeUtils = Runtime; #elif WITH_CPU using DefaultRuntimeUtils = Runtime; #endif diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h new file mode 100644 index 0000000..ceed55a --- /dev/null +++ b/src/ascend/gemm/kernel.h @@ -0,0 +1,80 @@ +#ifndef INFINI_OPS_ASCEND_GEMM_KERNEL_H_ +#define INFINI_OPS_ASCEND_GEMM_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_addmm.h" +#include "aclnnop/aclnn_baddbmm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/gemm.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Gemm { + public: + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm(a, b, alpha, beta, trans_a, trans_b, c), + batched_{batch_count_ > 1}, + alpha_val_{alpha.value_or(1.0f)}, + beta_val_{beta.value_or(1.0f)} { + alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT); + beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT); + } + + ~Operator() { + aclDestroyScalar(alpha_scalar_); + aclDestroyScalar(beta_scalar_); + } + + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const override { + auto stream = static_cast(stream_); + + auto t_self = ascend::buildAclTensor(c); + auto t_a = ascend::buildAclTensor(a, trans_a_); + auto t_b = ascend::buildAclTensor(b, trans_b_); + auto t_out = ascend::buildAclTensor(c); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_needed, + &executor); + } else { + aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_, + t_out, 0, &ws_needed, &executor); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + + if (batched_) { + aclnnBaddbmm(arena.buf, ws_needed, executor, stream); + } else { + aclnnAddmm(arena.buf, ws_needed, executor, stream); + } + + aclDestroyTensor(t_self); + aclDestroyTensor(t_a); + aclDestroyTensor(t_b); + aclDestroyTensor(t_out); + } + + private: + bool batched_; + float alpha_val_; + float beta_val_; + aclScalar* alpha_scalar_ = nullptr; + aclScalar* beta_scalar_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/conftest.py b/tests/conftest.py index 44654c3..8fb9f09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,6 +38,25 @@ def set_seed_per_test(request): _set_random_seed(_hash(_test_case_path_from_request(request))) +_NPU_UNSUPPORTED_DTYPES = 
{torch.float64} + +# torch_npu does not implement random number generation for uint16/uint32/uint64. +for _bits in (16, 32, 64): + _t = getattr(torch, f"uint{_bits}", None) + if _t is not None: + _NPU_UNSUPPORTED_DTYPES.add(_t) + + +@pytest.fixture(autouse=True) +def skip_unsupported_dtype(request): + if not hasattr(request.node, "callspec"): + return + params = request.node.callspec.params + + if params.get("device") == "npu" and params.get("dtype") in _NPU_UNSUPPORTED_DTYPES: + pytest.skip(f"{params['dtype']} not supported on Ascend 910B") + + def _set_random_seed(seed): random.seed(seed) torch.manual_seed(seed) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 40ed35d..af8b44f 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, randn_strided +from tests.utils import Payload, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -84,16 +84,22 @@ def test_gemm( def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): - infini.ops.gemm( - a, - b, - alpha, - beta, - trans_a, - trans_b, - c, - implementation_index=implementation_index, - ) + if a.device.type == "npu": + infini.ops.gemm( + a, b, alpha, beta, trans_a, trans_b, c, + stream=get_npu_stream(a), + ) + else: + infini.ops.gemm( + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, + implementation_index=implementation_index, + ) return c diff --git a/tests/utils.py b/tests/utils.py index aa4ee42..8412cd6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,12 +32,18 @@ def get_available_devices(): if hasattr(torch, "musa") and torch.musa.is_available(): devices.append("musa") + if hasattr(torch, "npu") and torch.npu.is_available(): + devices.append("npu") + return tuple(devices) with contextlib.suppress(ImportError, ModuleNotFoundError): import torch_mlu # noqa: F401 +with contextlib.suppress(ImportError, ModuleNotFoundError): + import torch_npu # noqa: F401 + def empty_strided(shape, strides, *, dtype=None, device=None): if strides is None: @@ -76,6 +82,14 @@ def randint_strided(low, high, shape, strides, *, dtype=None, device=None): return output +def get_npu_stream(tensor): + """Return the current NPU stream handle for `tensor`, or 0 on other devices.""" + if tensor.device.type != "npu": + return 0 + + return torch.npu.current_stream().npu_stream + + def clone_strided(input): output = empty_strided( input.size(), input.stride(), dtype=input.dtype, device=input.device From 21533e3e2cd8bcb3cabeb432b37a322e8cd99a6e Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 15:03:20 +0800 Subject: [PATCH 5/7] fix(ascend): move `aclrtMalloc` out of `assert()` in `WorkspacePool` The `aclrtMalloc` call was the sole expression inside `assert()`, so it was compiled away in release builds (NDEBUG). This left the workspace buffer null, causing `aclnnAddmm` to return ACLNN_ERR_PARAM_NULLPTR (161001) for any operation that requires workspace (e.g. alpha != 1.0). 
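
A minimal illustration of the pitfall, independent of any Ascend API
(`do_alloc` is hypothetical):

    #include <cassert>
    // Release builds define NDEBUG, which expands assert(expr) to nothing,
    // so any side effect inside the assert silently disappears:
    //   assert(do_alloc(&buf) == 0);   // -DNDEBUG: do_alloc() never runs
    // Keeping the side effect outside keeps both build modes equivalent:
    //   int ret = do_alloc(&buf);
    //   assert(ret == 0);

Note that the fixed code still leaves `ret` unused under NDEBUG; a
`(void)ret;` (or `[[maybe_unused]]`) would silence the resulting
unused-variable warning in release builds.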
--- src/ascend/workspace_pool_.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h index d97a20e..bac2479 100644 --- a/src/ascend/workspace_pool_.h +++ b/src/ascend/workspace_pool_.h @@ -26,9 +26,8 @@ class WorkspacePool { aclrtFree(arena.buf); } if (needed > 0) { - assert(aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY) == - ACL_SUCCESS && - "`WorkspacePool`: `aclrtMalloc` failed"); + auto ret = aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); } arena.capacity = needed; return arena; From cec3de851385e8d98c4e51331e76cfa3a5e09256 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 15:16:05 +0800 Subject: [PATCH 6/7] fix(nvidia): restore `CUDA::cublasLt` link dependency --- src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 17abb8c..a178836 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -40,7 +40,7 @@ if(WITH_NVIDIA) target_sources(infiniops PRIVATE ${NVIDIA_SOURCES}) find_package(CUDAToolkit REQUIRED) - target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver) + target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cuda_driver) list(APPEND DEVICE_LIST "nvidia") set_target_properties(infiniops PROPERTIES From cb2bab3dec8b7f05efa6cce7ac0ce00aa126f1b6 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 16:06:04 +0800 Subject: [PATCH 7/7] feat(test): add `--devices` option to pytest for platform-name filtering --- tests/conftest.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8fb9f09..344e452 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,12 @@ def pytest_addoption(parser): parser.addoption( "--benchmark", action="store_true", help="Run performance benchmarks." ) + parser.addoption( + "--devices", + nargs="+", + default=None, + help="Device(s) to test on (e.g., --devices ascend cpu). Accepts platform names (ascend, nvidia, cambricon, metax, moore, iluvatar) or PyTorch device types (npu, cuda, mlu, musa). Defaults to all available devices.", + ) def pytest_configure(config): @@ -62,6 +68,21 @@ def _set_random_seed(seed): torch.manual_seed(seed) +_PLATFORM_TO_TORCH_DEVICE = { + "nvidia": "cuda", + "iluvatar": "cuda", + "metax": "cuda", + "cambricon": "mlu", + "moore": "musa", + "ascend": "npu", +} + + +def _resolve_device(name): + """Map a platform name (e.g., ``ascend``) to a PyTorch device type (e.g., ``npu``).""" + return _PLATFORM_TO_TORCH_DEVICE.get(name, name) + + def pytest_generate_tests(metafunc): already_parametrized = _get_parametrized_args(metafunc) @@ -76,7 +97,17 @@ def pytest_generate_tests(metafunc): ) if "device" in metafunc.fixturenames and "device" not in already_parametrized: - metafunc.parametrize("device", get_available_devices()) + cli_devices = metafunc.config.getoption("--devices") + available = get_available_devices() + + if cli_devices: + devices = tuple( + d for d in (_resolve_device(x) for x in cli_devices) if d in available + ) + else: + devices = () + + metafunc.parametrize("device", devices or available) @pytest.hookimpl(tryfirst=True)
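
Example invocations for the new option (illustrative; which names are
accepted depends on the devices detected in the local build):

    pytest tests --devices ascend      # platform name, resolved to `npu`
    pytest tests --devices npu cpu     # PyTorch device types are accepted as-is
    pytest tests                       # default: every detected device

Names that resolve to unavailable devices are filtered out; if nothing
requested is available, the suite falls back to all detected devices rather
than collecting zero tests.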