diff --git a/CMakeLists.txt b/CMakeLists.txt index b9e2deb..7f4c5cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,7 @@ option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF) option(WITH_METAX "Enable MetaX backend" OFF) option(WITH_CAMBRICON "Enable Cambricon backend" OFF) option(WITH_MOORE "Enable Moore backend" OFF) +option(WITH_ASCEND "Enable Ascend backend" OFF) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) @@ -71,20 +72,25 @@ if(AUTO_DETECT_DEVICES) set(WITH_MOORE OFF) set(WITH_MOORE OFF CACHE BOOL "Enable Moore backend" FORCE) endif() + + if(DEFINED ENV{ASCEND_HOME_PATH} OR EXISTS "/dev/davinci0") + set(WITH_ASCEND ON) + message(STATUS "Auto-detected Ascend environment.") + endif() endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) # Only one CUDA-like GPU backend can be enabled at a time. set(_gpu_backend_count 0) -foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE) +foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE WITH_ASCEND) if(${_gpu_backend}) math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1") endif() endforeach() if(_gpu_backend_count GREATER 1) - message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.") + message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, `WITH_MOORE`, and `WITH_ASCEND` are mutually exclusive. Build one GPU backend at a time.") endif() if(WITH_NVIDIA) @@ -178,8 +184,23 @@ if(WITH_CAMBRICON) find_library(CAMBRICON_PAPI_LIB NAMES cnpapi HINTS "${NEUWARE_HOME}/lib64" REQUIRED) endif() +if(WITH_ASCEND) + add_compile_definitions(WITH_ASCEND=1) + if(NOT DEFINED ASCEND_HOME) + if(DEFINED ENV{ASCEND_HOME_PATH} AND NOT "$ENV{ASCEND_HOME_PATH}" STREQUAL "") + set(ASCEND_HOME "$ENV{ASCEND_HOME_PATH}" CACHE PATH "Ascend toolkit root") + else() + set(ASCEND_HOME "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Ascend toolkit root") + endif() + endif() + if(NOT EXISTS "${ASCEND_HOME}") + message(FATAL_ERROR "`WITH_ASCEND` is ON but `${ASCEND_HOME}` was not found. Set ASCEND_HOME_PATH.") + endif() + message(STATUS "Using Ascend from `${ASCEND_HOME}`.") +endif() + # If all other platforms are not enabled, CPU is enabled by default. -if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON) +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON AND NOT WITH_ASCEND) add_compile_definitions(WITH_CPU=1) endif() diff --git a/examples/runtime_api.h b/examples/runtime_api.h index 4c7469f..8b63153 100644 --- a/examples/runtime_api.h +++ b/examples/runtime_api.h @@ -19,6 +19,9 @@ #elif WITH_MOORE #include "moore/gemm/mublas.h" #include "moore/runtime_.h" +#elif WITH_ASCEND +#include "ascend/gemm/kernel.h" +#include "ascend/runtime_.h" #elif WITH_CPU #include "cpu/gemm/gemm.h" #include "cpu/runtime_.h" @@ -38,6 +41,8 @@ using DefaultRuntimeUtils = Runtime; using DefaultRuntimeUtils = Runtime; #elif WITH_MOORE using DefaultRuntimeUtils = Runtime; +#elif WITH_ASCEND +using DefaultRuntimeUtils = Runtime; #elif WITH_CPU using DefaultRuntimeUtils = Runtime; #endif diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 5aa8896..4580bed 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,6 +1,7 @@ import argparse import json import pathlib +import re import shutil import subprocess import textwrap @@ -91,26 +92,54 @@ def __init__(self, name, constructors, calls): self.calls = calls +def _find_optional_tensor_params(op_name): + """Return a set of parameter names declared as `std::optional` in + the base header. libclang resolves the type to ``int`` when the STL + headers are not fully available, so we fall back to a regex scan of the + source text. + """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + return set(re.findall(r"std::optional\s+(\w+)", source)) + + def _generate_pybind11(operator): + optional_tensor_params = _find_optional_tensor_params(operator.name) + + def _is_optional_tensor(arg): + if arg.spelling in optional_tensor_params: + return True + return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + def _generate_params(node): - return ( - ", ".join( - f"{arg.type.spelling} {arg.spelling}" - for arg in node.get_arguments() - if arg.spelling != "stream" - ) - .replace("const Tensor", "py::object") - .replace("Tensor", "py::object") - ) + parts = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + if _is_optional_tensor(arg): + parts.append(f"std::optional {arg.spelling}") + else: + param = arg.type.spelling.replace("const Tensor", "py::object").replace( + "Tensor", "py::object" + ) + parts.append(f"{param} {arg.spelling}") + + return ", ".join(parts) def _generate_arguments(node): - return ", ".join( - f"TensorFromPybind11Handle({arg.spelling})" - if "Tensor" in arg.type.spelling - else arg.spelling - for arg in node.get_arguments() - if arg.spelling != "stream" - ) + args = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + if _is_optional_tensor(arg): + args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})") + elif "Tensor" in arg.type.spelling: + args.append(f"TensorFromPybind11Handle({arg.spelling})") + else: + args.append(arg.spelling) + + return ", ".join(args) op_name = operator.name @@ -134,18 +163,24 @@ def _generate_call(op_name, call, method=True): if not method: params = ( - f"{call_params}, std::size_t implementation_index" + f"{call_params}, std::size_t implementation_index, std::uintptr_t stream" if call_params - else "std::size_t implementation_index" + else "std::size_t implementation_index, std::uintptr_t stream" ) py_args = _generate_py_args(call) py_args_str = f"{py_args}, " if py_args else "" - return f""" m.def("{op_name}", []({params}) {{ - Config config; - config.set_implementation_index(implementation_index); - return Self::call({{}}, config, {call_args}); - }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);""" + return ( + f' m.def("{op_name}", []({params}) {{\n' + f" Config config;\n" + f" config.set_implementation_index(implementation_index);\n" + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" return Self::call(handle, config, {call_args});\n" + f' }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0, py::arg("stream") = 0);' + ) return f""" .def("__call__", [](const Self& self, {call_params}) {{ return static_cast&>(self)({call_args}); @@ -169,6 +204,8 @@ def _generate_call(op_name, call, method=True): #include "base/{op_name}.h" #include "config.h" +#include "handle.h" +#include "operator.h" #include "pybind11_utils.h" namespace py = pybind11; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0b56341..a178836 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -172,10 +172,60 @@ if(WITH_CAMBRICON) list(APPEND DEVICE_LIST "cambricon") endif() +if(WITH_ASCEND) + # ASCEND_HOME is set by the top-level CMakeLists.txt. + file(GLOB_RECURSE ASCEND_SOURCES CONFIGURE_DEPENDS + "ascend/*.cc" + "ascend/*.cpp" + ) + # Exclude kernel_impl.cpp — AscendC device code, not compiled by the host C++ compiler. + list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*kernel_impl\\.cpp$") + + target_compile_definitions(infiniops PUBLIC WITH_ASCEND=1) + target_sources(infiniops PRIVATE ${ASCEND_SOURCES}) + + # Resolve the driver lib dir two levels above the toolkit root. + get_filename_component(ASCEND_ROOT "${ASCEND_HOME}/../.." ABSOLUTE) + + # Prefer the real driver HAL; fall back to the toolkit stub for build-only + # environments (e.g., Docker CI images without hardware drivers installed). + # CANN <= 8.0: stub at runtime/lib64/stub/; CANN >= 8.5: devlib/-linux/devlib/. + set(ASCEND_HAL_REAL "${ASCEND_ROOT}/driver/lib64/driver/libascend_hal.so") + set(ASCEND_HAL_STUB "${ASCEND_HOME}/runtime/lib64/stub/libascend_hal.so") + set(ASCEND_HAL_DEVLIB "${ASCEND_HOME}/${CMAKE_SYSTEM_PROCESSOR}-linux/devlib/libascend_hal.so") + if(EXISTS "${ASCEND_HAL_REAL}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_REAL}") + elseif(EXISTS "${ASCEND_HAL_STUB}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_STUB}") + message(STATUS "ascend_hal: driver not found, using stub for linking") + elseif(EXISTS "${ASCEND_HAL_DEVLIB}") + set(ASCEND_HAL_LIB "${ASCEND_HAL_DEVLIB}") + message(STATUS "ascend_hal: driver not found, using devlib for linking") + else() + message(FATAL_ERROR "libascend_hal.so not found (tried ${ASCEND_HAL_REAL}, ${ASCEND_HAL_STUB}, and ${ASCEND_HAL_DEVLIB})") + endif() + + target_include_directories(infiniops PUBLIC + "${ASCEND_HOME}/include" + "${ASCEND_HOME}/include/aclnn" + "${ASCEND_HOME}/include/aclnnop") + target_link_libraries(infiniops PUBLIC + "${ASCEND_HOME}/lib64/libascendcl.so" + "${ASCEND_HOME}/lib64/libnnopbase.so" + "${ASCEND_HOME}/lib64/libopapi.so" + "${ASCEND_HAL_LIB}") + + list(APPEND DEVICE_LIST "ascend") +endif() + target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) if(GENERATE_PYTHON_BINDINGS) find_package(Python COMPONENTS Interpreter REQUIRED) + # Always regenerate bindings so the included kernel headers match the + # active device list. Stale generated files (e.g., committed for one + # platform) would omit specializations for other enabled backends, + # causing link-time or runtime failures. execute_process( COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/generate_wrappers.py --devices ${DEVICE_LIST} WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} diff --git a/src/ascend/common.h b/src/ascend/common.h new file mode 100644 index 0000000..caa1062 --- /dev/null +++ b/src/ascend/common.h @@ -0,0 +1,56 @@ +#ifndef INFINI_OPS_ASCEND_COMMON_H_ +#define INFINI_OPS_ASCEND_COMMON_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/acl_meta.h" +#include "ascend/data_type_.h" +#include "tensor.h" + +namespace infini::ops::ascend { + +// Build an aclTensor descriptor from an InfiniOps Tensor. +// +// When `transpose_last2` is true the last two dimensions are swapped in the +// descriptor (shape and strides) without copying data. This is used by GEMM +// and Matmul to express a transpose via the view. +inline aclTensor* buildAclTensor(const Tensor& t, + bool transpose_last2 = false) { + std::vector shape(t.shape().begin(), t.shape().end()); + std::vector strides(t.strides().begin(), t.strides().end()); + + if (transpose_last2 && shape.size() >= 2) { + auto n = shape.size(); + std::swap(shape[n - 2], shape[n - 1]); + std::swap(strides[n - 2], strides[n - 1]); + } + + // Compute the minimum physical storage needed for this strided view. + // For contiguous tensors this equals `numel()`; for non-contiguous (gapped) + // tensors it may be larger; for broadcast (stride-0) tensors it may be + // smaller. Passing the view shape as the storage shape causes + // "ViewShape overlap" errors in ACLNN for non-contiguous inputs. + int64_t storage_elems = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == 0) { + storage_elems = 0; + break; + } + if (strides[i] > 0 && shape[i] > 1) { + storage_elems += static_cast(shape[i] - 1) * strides[i]; + } + } + std::vector storage_shape = {storage_elems}; + + return aclCreateTensor( + shape.data(), static_cast(shape.size()), toAclDtype(t.dtype()), + strides.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape.data(), + static_cast(storage_shape.size()), const_cast(t.data())); +} + +} // namespace infini::ops::ascend + +#endif diff --git a/src/ascend/data_type_.h b/src/ascend/data_type_.h new file mode 100644 index 0000000..08b1541 --- /dev/null +++ b/src/ascend/data_type_.h @@ -0,0 +1,61 @@ +#ifndef INFINI_OPS_ASCEND_DATA_TYPE__H_ +#define INFINI_OPS_ASCEND_DATA_TYPE__H_ + +#include + +#include "acl/acl.h" +#include "ascend/device_.h" +#include "data_type.h" + +namespace infini::ops::ascend { + +inline aclDataType toAclDtype(DataType dt) { + switch (dt) { + case DataType::kFloat16: + return ACL_FLOAT16; + case DataType::kBFloat16: + return ACL_BF16; + case DataType::kFloat32: + return ACL_FLOAT; + case DataType::kInt8: + return ACL_INT8; + case DataType::kInt16: + return ACL_INT16; + case DataType::kInt32: + return ACL_INT32; + case DataType::kInt64: + return ACL_INT64; + case DataType::kUInt8: + return ACL_UINT8; + case DataType::kUInt16: + return ACL_UINT16; + case DataType::kUInt32: + return ACL_UINT32; + case DataType::kUInt64: + return ACL_UINT64; + default: + assert(false && "unsupported dtype for Ascend backend"); + return ACL_DT_UNDEFINED; + } +} + +// Returns true for integer (signed or unsigned) DataType values. +inline bool isIntegerDtype(DataType dt) { + switch (dt) { + case DataType::kInt8: + case DataType::kInt16: + case DataType::kInt32: + case DataType::kInt64: + case DataType::kUInt8: + case DataType::kUInt16: + case DataType::kUInt32: + case DataType::kUInt64: + return true; + default: + return false; + } +} + +} // namespace infini::ops::ascend + +#endif diff --git a/src/ascend/device_.h b/src/ascend/device_.h new file mode 100644 index 0000000..b4ec934 --- /dev/null +++ b/src/ascend/device_.h @@ -0,0 +1,16 @@ +#ifndef INFINI_OPS_ASCEND_DEVICE__H_ +#define INFINI_OPS_ASCEND_DEVICE__H_ + +// NOTE: Cannot use `#include "device.h"` here — GCC resolves quoted includes +// relative to the current file first, and `src/ascend/` used to contain a +// `device.h`. Use `data_type.h` which transitively pulls in `src/device.h`. +#include "data_type.h" + +namespace infini::ops { + +template <> +struct DeviceEnabled : std::true_type {}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h new file mode 100644 index 0000000..5f32e27 --- /dev/null +++ b/src/ascend/gemm/kernel.h @@ -0,0 +1,84 @@ +#ifndef INFINI_OPS_ASCEND_GEMM_KERNEL_H_ +#define INFINI_OPS_ASCEND_GEMM_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_addmm.h" +#include "aclnnop/aclnn_baddbmm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/gemm.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Gemm { + public: + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm(a, b, alpha, beta, trans_a, trans_b, c), + batched_{batch_count_ > 1}, + alpha_val_{alpha.value_or(1.0f)}, + beta_val_{beta.value_or(1.0f)} { + alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT); + beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT); + } + + ~Operator() { + aclDestroyScalar(alpha_scalar_); + aclDestroyScalar(beta_scalar_); + } + + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const override { + auto stream = static_cast(stream_); + + auto t_self = ascend::buildAclTensor(c); + auto t_a = ascend::buildAclTensor(a, trans_a_); + auto t_b = ascend::buildAclTensor(b, trans_b_); + auto t_out = ascend::buildAclTensor(c); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_needed, + &executor); + } else { + aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_, + t_out, 0, &ws_needed, &executor); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + + if (batched_) { + aclnnBaddbmm(arena.buf, ws_needed, executor, stream); + } else { + aclnnAddmm(arena.buf, ws_needed, executor, stream); + } + + aclDestroyTensor(t_self); + aclDestroyTensor(t_a); + aclDestroyTensor(t_b); + aclDestroyTensor(t_out); + } + + private: + bool batched_; + + float alpha_val_; + + float beta_val_; + + aclScalar* alpha_scalar_ = nullptr; + + aclScalar* beta_scalar_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/runtime_.h b/src/ascend/runtime_.h new file mode 100644 index 0000000..dca7425 --- /dev/null +++ b/src/ascend/runtime_.h @@ -0,0 +1,44 @@ +#ifndef INFINI_OPS_ASCEND_RUNTIME__H_ +#define INFINI_OPS_ASCEND_RUNTIME__H_ + +// clang-format off +#include "acl/acl.h" +// clang-format on + +#include "ascend/device_.h" +#include "runtime.h" + +namespace infini::ops { + +template <> +struct Runtime + : DeviceRuntime> { + using Stream = aclrtStream; + + static constexpr Device::Type kDeviceType = Device::Type::kAscend; + + static constexpr auto Malloc = [](void** ptr, size_t size) { + return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); + }; + + static constexpr auto Free = aclrtFree; + + static constexpr auto Memcpy = [](void* dst, const void* src, size_t count, + aclrtMemcpyKind kind) { + return aclrtMemcpy(dst, count, src, count, kind); + }; + + static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; + + static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; + + static constexpr auto Memset = [](void* ptr, int value, size_t count) { + return aclrtMemset(ptr, count, value, count); + }; +}; + +static_assert(Runtime::Validate()); + +} // namespace infini::ops + +#endif diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h new file mode 100644 index 0000000..ebb670d --- /dev/null +++ b/src/ascend/workspace_pool_.h @@ -0,0 +1,56 @@ +#ifndef INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ +#define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ + +#include +#include +#include +#include + +#include "acl/acl.h" + +namespace infini::ops::ascend { + +struct WorkspaceArena { + void* buf = nullptr; + + uint64_t capacity = 0; +}; + +class WorkspacePool { + public: + WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) { + std::lock_guard lock(mutex_); + auto& arena = arenas_[stream]; + if (needed <= arena.capacity) return arena; + if (arena.capacity > 0) { + aclrtSynchronizeStream(stream); + aclrtFree(arena.buf); + } + if (needed > 0) { + auto ret = aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + } + arena.capacity = needed; + return arena; + } + + ~WorkspacePool() { + for (auto& [stream, arena] : arenas_) { + if (arena.capacity > 0) aclrtFree(arena.buf); + } + } + + private: + std::unordered_map arenas_; + + std::mutex mutex_; +}; + +inline WorkspacePool& workspacePool() { + static WorkspacePool pool; + return pool; +} + +} // namespace infini::ops::ascend + +#endif diff --git a/src/base/add_rms_norm.h b/src/base/add_rms_norm.h new file mode 100644 index 0000000..8243a53 --- /dev/null +++ b/src/base/add_rms_norm.h @@ -0,0 +1,50 @@ +#ifndef INFINI_OPS_BASE_ADD_RMS_NORM_H_ +#define INFINI_OPS_BASE_ADD_RMS_NORM_H_ + +#include +#include + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class AddRmsNorm : public Operator { + public: + AddRmsNorm(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : input_shape_{x1.shape()}, + eps_{eps}, + dim_{x1.size(-1)}, + ndim_{x1.ndim()}, + batch_size_{ndim_ == 2 ? x1.size(-2) : x1.size(-3)}, + nhead_{ndim_ == 2 ? 1 : x1.size(-2)}, + rstd_shape_{static_cast(batch_size_), + static_cast(nhead_)} { + assert(x1.dtype() == x2.dtype()); + assert(x1.dtype() == y_out.dtype()); + assert(x1.dtype() == x_out.dtype()); + } + + virtual void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const = 0; + + protected: + Tensor::Shape input_shape_; + + float eps_{1e-6f}; + + Tensor::Size dim_{0}; + + Tensor::Size ndim_{0}; + + Tensor::Size batch_size_{0}; + + Tensor::Size nhead_{1}; + + std::vector rstd_shape_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/flash_attention.h b/src/base/flash_attention.h new file mode 100644 index 0000000..734e9a2 --- /dev/null +++ b/src/base/flash_attention.h @@ -0,0 +1,104 @@ +#ifndef INFINI_OPS_BASE_FLASH_ATTENTION_H_ +#define INFINI_OPS_BASE_FLASH_ATTENTION_H_ + +#include +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class FlashAttention : public Operator { + public: + FlashAttention(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, int64_t window_right, + int64_t block_size, Tensor output) + : num_tokens_{query.size(0)}, + num_heads_{num_heads}, + num_kv_heads_{num_kv_heads}, + head_size_{head_size}, + scale_{scale}, + causal_{causal}, + window_left_{window_left}, + window_right_{window_right}, + block_size_{block_size}, + dtype_{query.dtype()}, + query_shape_{query.shape()}, + key_shape_{key.shape()}, + value_shape_{value.shape()}, + output_shape_{output.shape()}, + query_strides_{query.strides()}, + key_strides_{key.strides()}, + value_strides_{value.strides()}, + output_strides_{output.strides()}, + has_cu_seqlens_q_{cu_seqlens_q.has_value()}, + has_cu_seqlens_kv_{cu_seqlens_kv.has_value()}, + has_block_table_{block_table.has_value()} { + assert(num_heads % num_kv_heads == 0 && + "`FlashAttention` requires num_heads divisible by num_kv_heads"); + assert(query.ndim() == 3 && + "`FlashAttention` requires query to be 3D [T, N, D]"); + } + + virtual void operator()(const Tensor query, const Tensor key, + const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, + int64_t window_right, int64_t block_size, + Tensor output) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + double scale_{0.0}; + + bool causal_{false}; + + int64_t window_left_{-1}; + + int64_t window_right_{-1}; + + int64_t block_size_{0}; + + const DataType dtype_; + + Tensor::Shape query_shape_; + + Tensor::Shape key_shape_; + + Tensor::Shape value_shape_; + + Tensor::Shape output_shape_; + + Tensor::Strides query_strides_; + + Tensor::Strides key_strides_; + + Tensor::Strides value_strides_; + + Tensor::Strides output_strides_; + + bool has_cu_seqlens_q_{false}; + + bool has_cu_seqlens_kv_{false}; + + bool has_block_table_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/matmul.h b/src/base/matmul.h new file mode 100644 index 0000000..071feae --- /dev/null +++ b/src/base/matmul.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_BASE_MATMUL_H_ +#define INFINI_OPS_BASE_MATMUL_H_ + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class Matmul : public Operator { + public: + // `trans_a` / `trans_b`: If true, transpose the last two dims of `a` / `b` + // before multiplying. These are constructor parameters so the `CacheKey` + // encodes the transposition and distinct descriptors are cached for each + // combination. + Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + c_shape_{c.shape()}, + trans_a_{trans_a}, + trans_b_{trans_b} { + assert(a.dtype() == b.dtype()); + } + + virtual void operator()(const Tensor a, const Tensor b, Tensor c, + bool trans_a, bool trans_b) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape c_shape_; + + bool trans_a_{false}; + + bool trans_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/reshape_and_cache.h b/src/base/reshape_and_cache.h new file mode 100644 index 0000000..5d0adfa --- /dev/null +++ b/src/base/reshape_and_cache.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_BASE_RESHAPE_AND_CACHE_H_ +#define INFINI_OPS_BASE_RESHAPE_AND_CACHE_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class ReshapeAndCache : public Operator { + public: + ReshapeAndCache(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : num_tokens_{key.size(0)}, + num_kv_heads_{key.size(1)}, + head_size_{key.size(2)}, + block_size_{kv_cache.size(2)}, + key_shape_{key.shape()}, + value_shape_{value.shape()}, + kv_cache_shape_{kv_cache.shape()}, + slot_mapping_shape_{slot_mapping.shape()}, + key_strides_{key.strides()}, + value_strides_{value.strides()}, + kv_cache_strides_{kv_cache.strides()}, + slot_mapping_strides_{slot_mapping.strides()}, + kv_cache_out_strides_{kv_cache_out.strides()} { + assert(key.shape() == value.shape() && + "`ReshapeAndCache` requires key and value same shape"); + assert(kv_cache.ndim() == 5 && + "`ReshapeAndCache` requires kv_cache to be 5D [2, num_blocks, " + "block_size, num_kv_heads, head_size]"); + assert(slot_mapping.ndim() == 1 && + "`ReshapeAndCache` requires slot_mapping to be 1D"); + } + + virtual void operator()(const Tensor key, const Tensor value, + const Tensor kv_cache, const Tensor slot_mapping, + Tensor kv_cache_out) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + Tensor::Size num_kv_heads_{0}; + + Tensor::Size head_size_{0}; + + Tensor::Size block_size_{0}; + + Tensor::Shape key_shape_; + + Tensor::Shape value_shape_; + + Tensor::Shape kv_cache_shape_; + + Tensor::Shape slot_mapping_shape_; + + Tensor::Strides key_strides_; + + Tensor::Strides value_strides_; + + Tensor::Strides kv_cache_strides_; + + Tensor::Strides slot_mapping_strides_; + + Tensor::Strides kv_cache_out_strides_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/rotary_embedding.h b/src/base/rotary_embedding.h new file mode 100644 index 0000000..70989fa --- /dev/null +++ b/src/base/rotary_embedding.h @@ -0,0 +1,80 @@ +#ifndef INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ +#define INFINI_OPS_BASE_ROTARY_EMBEDDING_H_ + +#include +#include + +#include "operator.h" + +namespace infini::ops { + +class RotaryEmbedding : public Operator { + public: + RotaryEmbedding(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) + : num_tokens_{query.size(0)}, + num_heads_{static_cast(query.size(1))}, + num_kv_heads_{static_cast(key.size(1))}, + head_size_{head_size}, + rotary_dim_{rotary_dim}, + is_neox_style_{is_neox_style}, + query_shape_{query.shape()}, + key_shape_{key.shape()}, + cos_sin_cache_shape_{cos_sin_cache.shape()}, + query_out_shape_{query_out.shape()}, + key_out_shape_{key_out.shape()}, + query_strides_{query.strides()}, + key_strides_{key.strides()}, + query_out_strides_{query_out.strides()}, + key_out_strides_{key_out.strides()} { + assert(query.ndim() == 3 && + "`RotaryEmbedding` requires query to be 3D [T, N, D]"); + assert(key.ndim() == 3 && + "`RotaryEmbedding` requires key to be 3D [T, N_kv, D]"); + assert(rotary_dim <= head_size && + "`RotaryEmbedding` requires rotary_dim <= head_size"); + } + + virtual void operator()(const Tensor positions, const Tensor query, + const Tensor key, const Tensor cos_sin_cache, + int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, + Tensor key_out) const = 0; + + protected: + Tensor::Size num_tokens_{0}; + + int64_t num_heads_{0}; + + int64_t num_kv_heads_{0}; + + int64_t head_size_{0}; + + int64_t rotary_dim_{0}; + + bool is_neox_style_{true}; + + Tensor::Shape query_shape_; + + Tensor::Shape key_shape_; + + Tensor::Shape cos_sin_cache_shape_; + + Tensor::Shape query_out_shape_; + + Tensor::Shape key_out_shape_; + + Tensor::Strides query_strides_; + + Tensor::Strides key_strides_; + + Tensor::Strides query_out_strides_; + + Tensor::Strides key_out_strides_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 7c7ac87..cffa071 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -7,6 +7,7 @@ #include "base/causal_softmax.h" #include "cuda/causal_softmax/kernel.cuh" #include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" #include "data_type.h" #include "dispatcher.h" diff --git a/src/operator.h b/src/operator.h index 76efd7a..dbe92d7 100644 --- a/src/operator.h +++ b/src/operator.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -176,10 +177,10 @@ class Operator : public OperatorBase { auto it{cache.find(key)}; if (it == cache.end()) { - it = cache - .emplace(std::move(key), - make(config, std::forward(args)...)) - .first; + // Pass args as lvalue refs so they remain valid for the `operator()` call + // below. Forwarding rvalue temporaries into `make()` would leave the args + // in a moved-from (empty) state before operator() can use them. + it = cache.emplace(std::move(key), make(config, args...)).first; } auto& op{it->second}; diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 0f5e73b..766b6ea 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -116,6 +116,12 @@ inline Tensor TensorFromPybind11Handle(py::handle obj) { return Tensor{data, std::move(shape), dtype, device, std::move(strides)}; } +inline std::optional OptionalTensorFromPybind11Handle( + const std::optional& obj) { + if (!obj.has_value()) return std::nullopt; + return TensorFromPybind11Handle(*obj); +} + } // namespace infini::ops #endif diff --git a/tests/conftest.py b/tests/conftest.py index 44654c3..905e011 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,12 @@ def pytest_addoption(parser): parser.addoption( "--benchmark", action="store_true", help="Run performance benchmarks." ) + parser.addoption( + "--devices", + nargs="+", + default=None, + help="Device(s) to test on (e.g., --devices ascend cpu). Accepts platform names (ascend, nvidia, cambricon, metax, moore, iluvatar) or PyTorch device types (npu, cuda, mlu, musa). Defaults to all available devices.", + ) def pytest_configure(config): @@ -38,11 +44,46 @@ def set_seed_per_test(request): _set_random_seed(_hash(_test_case_path_from_request(request))) +_NPU_UNSUPPORTED_DTYPES = {torch.float64} + +# `torch_npu` does not implement random number generation for `uint16`/`uint32`/`uint64`. +for _bits in (16, 32, 64): + _t = getattr(torch, f"uint{_bits}", None) + if _t is not None: + _NPU_UNSUPPORTED_DTYPES.add(_t) + + +@pytest.fixture(autouse=True) +def skip_unsupported_dtype(request): + if not hasattr(request.node, "callspec"): + return + + params = request.node.callspec.params + + if params.get("device") == "npu" and params.get("dtype") in _NPU_UNSUPPORTED_DTYPES: + pytest.skip(f"{params['dtype']} not supported on Ascend 910B") + + def _set_random_seed(seed): random.seed(seed) torch.manual_seed(seed) +_PLATFORM_TO_TORCH_DEVICE = { + "nvidia": "cuda", + "iluvatar": "cuda", + "metax": "cuda", + "cambricon": "mlu", + "moore": "musa", + "ascend": "npu", +} + + +def _resolve_device(name): + """Map a platform name (e.g., ``ascend``) to a PyTorch device type (e.g., ``npu``).""" + return _PLATFORM_TO_TORCH_DEVICE.get(name, name) + + def pytest_generate_tests(metafunc): already_parametrized = _get_parametrized_args(metafunc) @@ -57,7 +98,17 @@ def pytest_generate_tests(metafunc): ) if "device" in metafunc.fixturenames and "device" not in already_parametrized: - metafunc.parametrize("device", get_available_devices()) + cli_devices = metafunc.config.getoption("--devices") + available = get_available_devices() + + if cli_devices: + devices = tuple( + d for d in (_resolve_device(x) for x in cli_devices) if d in available + ) + else: + devices = () + + metafunc.parametrize("device", devices or available) @pytest.hookimpl(tryfirst=True) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 40ed35d..3f48562 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, randn_strided +from tests.utils import Payload, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -84,16 +84,28 @@ def test_gemm( def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): - infini.ops.gemm( - a, - b, - alpha, - beta, - trans_a, - trans_b, - c, - implementation_index=implementation_index, - ) + if a.device.type == "npu": + infini.ops.gemm( + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, + stream=get_npu_stream(a), + ) + else: + infini.ops.gemm( + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, + implementation_index=implementation_index, + ) return c diff --git a/tests/utils.py b/tests/utils.py index aa4ee42..8412cd6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,12 +32,18 @@ def get_available_devices(): if hasattr(torch, "musa") and torch.musa.is_available(): devices.append("musa") + if hasattr(torch, "npu") and torch.npu.is_available(): + devices.append("npu") + return tuple(devices) with contextlib.suppress(ImportError, ModuleNotFoundError): import torch_mlu # noqa: F401 +with contextlib.suppress(ImportError, ModuleNotFoundError): + import torch_npu # noqa: F401 + def empty_strided(shape, strides, *, dtype=None, device=None): if strides is None: @@ -76,6 +82,14 @@ def randint_strided(low, high, shape, strides, *, dtype=None, device=None): return output +def get_npu_stream(tensor): + """Return the current NPU stream handle for `tensor`, or 0 on other devices.""" + if tensor.device.type != "npu": + return 0 + + return torch.npu.current_stream().npu_stream + + def clone_strided(input): output = empty_strided( input.size(), input.stride(), dtype=input.dtype, device=input.device