From 1e78a0243aecbef5deed9009669de9341f8b5775 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 16:48:16 +0800 Subject: [PATCH 1/8] feat(ascend): add 9 Ascend operator kernels Add, RmsNorm, Swiglu, Matmul, CausalSoftmax, AddRmsNorm, ReshapeAndCache, RotaryEmbedding, FlashAttention. --- src/ascend/add/kernel.h | 58 +++ src/ascend/add_rms_norm/kernel.h | 64 ++++ src/ascend/causal_softmax/kernel.h | 127 +++++++ src/ascend/flash_attention/kernel.h | 321 ++++++++++++++++ src/ascend/matmul/kernel.h | 44 +++ src/ascend/reshape_and_cache/kernel.h | 71 ++++ src/ascend/rms_norm/kernel.h | 62 ++++ src/ascend/rotary_embedding/kernel.h | 505 ++++++++++++++++++++++++++ src/ascend/swiglu/kernel.h | 70 ++++ src/base/rotary_embedding.h | 4 +- 10 files changed, 1324 insertions(+), 2 deletions(-) create mode 100644 src/ascend/add/kernel.h create mode 100644 src/ascend/add_rms_norm/kernel.h create mode 100644 src/ascend/causal_softmax/kernel.h create mode 100644 src/ascend/flash_attention/kernel.h create mode 100644 src/ascend/matmul/kernel.h create mode 100644 src/ascend/reshape_and_cache/kernel.h create mode 100644 src/ascend/rms_norm/kernel.h create mode 100644 src/ascend/rotary_embedding/kernel.h create mode 100644 src/ascend/swiglu/kernel.h diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h new file mode 100644 index 0000000..e81f9bd --- /dev/null +++ b/src/ascend/add/kernel.h @@ -0,0 +1,58 @@ +#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add(input, other, out) { + // aclCreateScalar stores the pointer rather than copying the value, so + // alpha_storage_* must 
remain alive for the lifetime of alpha_. + // The alpha scalar type must match the tensor dtype: use int64 for integer + // dtypes and float for floating-point dtypes. + if (ascend::isIntegerDtype(input.dtype())) { + alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64); + } else { + alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT); + } + } + + ~Operator() { aclDestroyScalar(alpha_); } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = ascend::buildAclTensor(input); + auto t_oth = ascend::buildAclTensor(other); + auto t_out = ascend::buildAclTensor(out); + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_needed, &executor); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnAdd(arena.buf, ws_needed, executor, stream); + aclDestroyTensor(t_in); + aclDestroyTensor(t_oth); + aclDestroyTensor(t_out); + } + + private: + float alpha_float_storage_ = + 1.0f; // stable address for aclCreateScalar (float) + int64_t alpha_int_storage_ = 1; // stable address for aclCreateScalar (int) + aclScalar* alpha_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/kernel.h b/src/ascend/add_rms_norm/kernel.h new file mode 100644 index 0000000..28ae702 --- /dev/null +++ b/src/ascend/add_rms_norm/kernel.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add_rms_norm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add_rms_norm.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, 
gamma, eps, y_out, x_out) { + // aclnnAddRmsNorm writes rstd as a required side output. + // Allocate a persistent device buffer for it. + size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + + ~Operator() { + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto t_x1 = ascend::buildAclTensor(x1); + auto t_x2 = ascend::buildAclTensor(x2); + auto t_gamma = ascend::buildAclTensor(gamma); + auto t_y_out = ascend::buildAclTensor(y_out); + auto t_x_out = ascend::buildAclTensor(x_out); + // rstd is always float32 regardless of input dtype. + auto t_rstd = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_data_); + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclnnAddRmsNormGetWorkspaceSize(t_x1, t_x2, t_gamma, eps, t_y_out, t_rstd, + t_x_out, &ws_needed, &executor); + auto stream = static_cast(stream_); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnAddRmsNorm(arena.buf, ws_needed, executor, stream); + aclDestroyTensor(t_x1); + aclDestroyTensor(t_x2); + aclDestroyTensor(t_gamma); + aclDestroyTensor(t_y_out); + aclDestroyTensor(t_rstd); + aclDestroyTensor(t_x_out); + } + + private: + void* rstd_data_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/causal_softmax/kernel.h b/src/ascend/causal_softmax/kernel.h new file mode 100644 index 0000000..5883c42 --- /dev/null +++ b/src/ascend/causal_softmax/kernel.h @@ -0,0 +1,127 @@ +#ifndef INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnn_masked_fill_scalar.h" +#include "aclnn_softmax.h" +#include "ascend/common.h" 
+#include "ascend/workspace_pool_.h" +#include "base/causal_softmax.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements causal softmax via three ACLNN calls: +// 1. InplaceCopy(temp, input) — stride-aware copy to contiguous temp +// buffer. +// 2. InplaceMaskedFillScalar(temp, mask, -inf) — apply upper-triangle mask. +// 3. Softmax(temp, dim=-1, out) — softmax over the last dimension. +// +// The boolean causal mask is pre-computed and uploaded to device once in the +// constructor. Its shape (seq_len, total_seq_len) broadcasts over the batch. +template <> +class Operator : public CausalSoftmax { + public: + Operator(const Tensor input, Tensor out) : CausalSoftmax(input, out) { + // Contiguous temp buffer with the same element count as input. + size_t n_elems = input.numel(); + size_t elem_bytes = kDataTypeToSize.at(dtype_); + aclrtMalloc(&temp_buf_, n_elems * elem_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // Build a contiguous Tensor descriptor pointing to temp_buf_. + Tensor temp_t{temp_buf_, input.shape(), input.dtype(), input.device()}; + + // Causal mask: mask[i][j] = 1 when position j must be masked for query i. + // Shape (seq_len, total_seq_len) – broadcasts over the batch dimension. 
+ size_t mask_elems = seq_len_ * total_seq_len_; + std::vector mask_host(mask_elems, 0); + + for (size_t i = 0; i < seq_len_; ++i) { + auto vis_end = static_cast(total_seq_len_ - seq_len_ + i); + + for (auto j = vis_end + 1; j < static_cast(total_seq_len_); + ++j) { + mask_host[i * total_seq_len_ + j] = 1; + } + } + + aclrtMalloc(&mask_buf_, mask_elems, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(mask_buf_, mask_elems, mask_host.data(), mask_elems, + ACL_MEMCPY_HOST_TO_DEVICE); + + std::vector mshape = {static_cast(seq_len_), + static_cast(total_seq_len_)}; + std::vector mstrides = {static_cast(total_seq_len_), 1}; + mask_tensor_ = aclCreateTensor(mshape.data(), mshape.size(), ACL_BOOL, + mstrides.data(), 0, ACL_FORMAT_ND, + mshape.data(), mshape.size(), mask_buf_); + + // Scalar -inf for the masked-fill step. aclCreateScalar stores the pointer + // rather than copying, so neg_inf_storage_ must stay alive with the object. + neg_inf_ = aclCreateScalar(&neg_inf_storage_, ACL_FLOAT); + // Workspaces are allocated lazily on first operator() call. + } + + ~Operator() { + aclrtFree(temp_buf_); + aclrtFree(mask_buf_); + aclDestroyTensor(mask_tensor_); + aclDestroyScalar(neg_inf_); + } + + void operator()(const Tensor input, Tensor out) const override { + Tensor temp_t{temp_buf_, input.shape(), input.dtype(), input.device()}; + auto t_in = ascend::buildAclTensor(input); + auto t_temp = ascend::buildAclTensor(temp_t); + auto t_out = ascend::buildAclTensor(out); + auto stream = static_cast(stream_); + + uint64_t ws_needed = 0; + aclOpExecutor* exec = nullptr; + + // Step 1: copy input (possibly non-contiguous) into contiguous temp. + aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, &ws_needed, &exec); + auto& copy_arena = ascend::workspacePool().ensure(stream, ws_needed); + uint64_t copy_ws = ws_needed; + aclnnInplaceCopy(copy_arena.buf, copy_ws, exec, stream); + + // Step 2: mask upper-triangle positions with -inf in-place. 
+ ws_needed = 0; + exec = nullptr; + aclnnInplaceMaskedFillScalarGetWorkspaceSize(t_temp, mask_tensor_, neg_inf_, + &ws_needed, &exec); + auto& fill_arena = ascend::workspacePool().ensure(stream, ws_needed); + uint64_t fill_ws = ws_needed; + aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws, exec, stream); + + // Step 3: softmax over the last dimension → out. + ws_needed = 0; + exec = nullptr; + constexpr int64_t kLastDim = -1; + aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &ws_needed, &exec); + auto& softmax_arena = ascend::workspacePool().ensure(stream, ws_needed); + uint64_t softmax_ws = ws_needed; + aclnnSoftmax(softmax_arena.buf, softmax_ws, exec, stream); + + aclDestroyTensor(t_in); + aclDestroyTensor(t_temp); + aclDestroyTensor(t_out); + } + + private: + float neg_inf_storage_ = -std::numeric_limits::infinity(); + void* temp_buf_ = nullptr; + void* mask_buf_ = nullptr; + aclTensor* mask_tensor_ = nullptr; + aclScalar* neg_inf_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/flash_attention/kernel.h b/src/ascend/flash_attention/kernel.h new file mode 100644 index 0000000..3b82e53 --- /dev/null +++ b/src/ascend/flash_attention/kernel.h @@ -0,0 +1,321 @@ +#ifndef INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ +#define INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_fused_infer_attention_score_v4.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/flash_attention.h" +#include "operator.h" + +namespace infini::ops { + +namespace detail { + +// Build an aclTensor with a different view shape/stride but the same data +// pointer. 
+inline aclTensor* reshapeView(const Tensor& t, + const std::vector& new_shape, + const std::vector& new_strides) { + int64_t storage_elems = 1; + for (size_t i = 0; i < new_shape.size(); ++i) { + if (new_shape[i] == 0) { + storage_elems = 0; + break; + } + if (new_strides[i] > 0 && new_shape[i] > 1) { + storage_elems += static_cast(new_shape[i] - 1) * new_strides[i]; + } + } + std::vector storage_shape = {storage_elems}; + return aclCreateTensor( + new_shape.data(), static_cast(new_shape.size()), + ascend::toAclDtype(t.dtype()), new_strides.data(), 0, ACL_FORMAT_ND, + storage_shape.data(), static_cast(storage_shape.size()), + const_cast(t.data())); +} + +// Extract cu_seqlens differences to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...]. +// Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths). +inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + std::vector cu_host(n); + aclrtMemcpyAsync(cu_host.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + std::vector lengths(n - 1); + for (size_t i = 0; i < lengths.size(); ++i) { + lengths[i] = cu_host[i + 1] - cu_host[i]; + } + return aclCreateIntArray(lengths.data(), + static_cast(lengths.size())); +} + +// Extract cumulative end positions from cu_seqlens to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...]. +// FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend +// convention for npu_fused_infer_attention_score actual_seq_lengths. 
+inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + std::vector cu_host(n); + aclrtMemcpyAsync(cu_host.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + // Skip the leading 0; return [s1, s1+s2, ...]. + return aclCreateIntArray(cu_host.data() + 1, static_cast(n - 1)); +} + +// Allocate a 2048x2048 lower-triangular UINT8 causal mask on device. +// Required for sparseMode >= 2. +inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) { + constexpr int64_t kMaskDim = 2048; + const int64_t mask_elems = kMaskDim * kMaskDim; + const size_t mask_bytes = static_cast(mask_elems); // uint8_t + + aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + std::vector host_mask(mask_elems); + for (int64_t r = 0; r < kMaskDim; ++r) { + for (int64_t c = 0; c < kMaskDim; ++c) { + // 1 = masked out (upper triangle); 0 = attend (lower triangle). + host_mask[r * kMaskDim + c] = (c > r) ? 
1 : 0; + } + } + aclrtMemcpyAsync(*mask_buf, mask_bytes, host_mask.data(), mask_bytes, + ACL_MEMCPY_HOST_TO_DEVICE, stream); + aclrtSynchronizeStream(stream); + + std::vector mask_shape = {kMaskDim, kMaskDim}; + std::vector mask_strides = {kMaskDim, 1}; + std::vector mask_storage = {mask_elems}; + return aclCreateTensor(mask_shape.data(), 2, ACL_UINT8, mask_strides.data(), + 0, ACL_FORMAT_ND, mask_storage.data(), 1, *mask_buf); +} + +} // namespace detail + +template <> +class Operator : public FlashAttention { + public: + using FlashAttention::FlashAttention; + + void operator()(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, int64_t window_right, + int64_t block_size, Tensor output) const override { + auto stream = static_cast(stream_); + const bool paged = block_table.has_value() && block_size > 0; + + // Map causal + window_left/right to FIA sparse_mode / preTokens / + // nextTokens. + // + // causal=true, window_left<0 -> sparse_mode=3 (full causal) + // causal=true, window_left>=0 -> sparse_mode=4 (sliding + // window causal) causal=false -> sparse_mode=0 + // (no mask) + // + // sparse_mode is ignored by FIA when Q_S=1 (paged decode); effective_sparse + // is set to 0 in that path to avoid allocating the unnecessary causal mask. 
+ int64_t sparse_mode; + int64_t pre_tokens = 2147483647; + int64_t next_tokens = 2147483647; + if (causal) { + if (window_left >= 0) { + sparse_mode = 4; // band: sliding window causal + pre_tokens = window_left; + next_tokens = 0; + } else { + sparse_mode = 3; // rightDownCausal: full causal, pre/next ignored + next_tokens = 0; + } + } else { + sparse_mode = 0; + if (window_left >= 0) pre_tokens = window_left; + if (window_right >= 0) next_tokens = window_right; + } + + if (!paged) { + // --- Prefill (single- or multi-sequence) --- + // V4 TND: query/key/value passed as token-packed [T, N, D]; per-sequence + // lengths are derived from cu_seqlens. Single fused call for all + // sequences, equivalent to flash_attn_varlen_func on CUDA. + int64_t T = query.size(0); + + // V4 TND varlen uses cumulative end positions [s1, s1+s2, ...]. + // For single-seq (no cu_seqlens), [T] is both per-seq and cumulative. + aclIntArray* seq_q = + cu_seqlens_q.has_value() + ? detail::cumSeqLengths(cu_seqlens_q.value(), stream) + : aclCreateIntArray(&T, 1); + aclIntArray* seq_kv = + cu_seqlens_kv.has_value() + ? detail::cumSeqLengths(cu_seqlens_kv.value(), stream) + : aclCreateIntArray(&T, 1); + + aclTensor* t_q = ascend::buildAclTensor(query); + aclTensor* t_k = ascend::buildAclTensor(key); + aclTensor* t_v = ascend::buildAclTensor(value); + aclTensor* t_out = ascend::buildAclTensor(output); + + const aclTensor* k_arr[] = {t_k}; + const aclTensor* v_arr[] = {t_v}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + // sparseMode 2/3/4 require a 2048x2048 lower-triangular causal mask. 
+ aclTensor* atten_mask = nullptr; + void* mask_buf = nullptr; + if (sparse_mode >= 2) { + atten_mask = detail::makeCausalMask(&mask_buf, stream); + } + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + // Parameter order: query, key, value, + // pseShift, attenMask, actualSeqLengths, actualSeqLengthsKv, + // deqScale1, quantScale1, deqScale2, quantScale2, quantOffset2, + // antiquantScale, antiquantOffset, + // blockTable, queryPaddingSize, kvPaddingSize, + // keyAntiquantScale, keyAntiquantOffset, + // valueAntiquantScale, valueAntiquantOffset, + // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen, + // queryRope, keyRope, keyRopeAntiquantScale, + // dequantScaleQuery, learnableSink, + // numHeads, scaleValue, preTokens, nextTokens, inputLayout, + // numKeyValueHeads, sparseMode, innerPrecise, blockSize, + // antiquantMode, softmaxLseFlag, + // keyAntiquantMode, valueAntiquantMode, queryQuantMode, + // attentionOut, softmaxLse, workspaceSize, executor + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_q, key_list, val_list, + nullptr, // pseShift + atten_mask, // attenMask + seq_q, // actualSeqLengths + seq_kv, // actualSeqLengthsKv + nullptr, nullptr, nullptr, nullptr, + nullptr, // deqScale1..quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // queryPaddingSize, kvPaddingSize + nullptr, nullptr, nullptr, + nullptr, // key/value antiquant scale/offset + nullptr, nullptr, + nullptr, // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen + nullptr, nullptr, + nullptr, // queryRope, keyRope, keyRopeAntiquantScale + nullptr, nullptr, // dequantScaleQuery, learnableSink + num_heads, scale, pre_tokens, next_tokens, const_cast("TND"), + num_kv_heads, sparse_mode, + 0, // innerPrecise + 0, // blockSize (unused for prefill) + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_out, nullptr, 
&ws_needed, &executor); + assert( + gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (prefill)"); + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclError ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, + executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (prefill)"); + + aclDestroyTensor(t_q); + aclDestroyTensor(t_out); + aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyIntArray(seq_q); + aclDestroyIntArray(seq_kv); + if (atten_mask) aclDestroyTensor(atten_mask); + if (mask_buf) aclrtFree(mask_buf); + return; + } + + // --- Paged decode --- + // V4 BNSD: reshape query/output [B, N, D] -> [B, N, 1, D]. + // KV cache [num_blocks, block_size, N_kv, D] flattened to + // [num_blocks, block_size, N_kv*D] (zero-copy, FIA BSH kv format). + assert(cu_seqlens_kv.has_value() && + "`FlashAttention` paged decode requires `cu_seqlens_kv`"); + + const int64_t N = query.size(1); + const int64_t D = query.size(2); + const int64_t B = query.size(0); + const int64_t nb = key.size(0); + const int64_t bsz = key.size(1); + const int64_t NkvD = key.size(2) * key.size(3); + + std::vector bnsd_sh = {B, N, 1, D}; + std::vector bnsd_st = {N * D, D, D, 1}; + aclTensor* t_query = detail::reshapeView(query, bnsd_sh, bnsd_st); + aclTensor* t_output = detail::reshapeView(output, bnsd_sh, bnsd_st); + + std::vector kv_sh = {nb, bsz, NkvD}; + std::vector kv_st = {bsz * NkvD, NkvD, 1}; + aclTensor* t_key = detail::reshapeView(key, kv_sh, kv_st); + aclTensor* t_value = detail::reshapeView(value, kv_sh, kv_st); + + aclIntArray* seq_kv = + detail::extractSeqLengths(cu_seqlens_kv.value(), stream); + aclTensor* t_block_table = ascend::buildAclTensor(block_table.value()); + + const aclTensor* k_arr[] = {t_key}; + const aclTensor* v_arr[] = {t_value}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + 
uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_query, key_list, val_list, + nullptr, // pseShift + nullptr, // attenMask (sparseMode ignored for Q_S=1) + nullptr, // actualSeqLengths (ignored for Q_S=1) + seq_kv, // actualSeqLengthsKv (mandatory for paged) + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + t_block_table, // blockTable + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, num_heads, scale, + static_cast(2147483647), static_cast(2147483647), + const_cast("BNSD"), num_kv_heads, + 0, // sparseMode=0 (ignored for Q_S=1) + 0, // innerPrecise + block_size, // blockSize + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_output, nullptr, &ws_needed, &executor); + assert(gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)"); + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclError ret = + aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (decode)"); + + aclDestroyTensor(t_query); + aclDestroyTensor(t_output); + aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyTensor(t_block_table); + aclDestroyIntArray(seq_kv); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/matmul/kernel.h b/src/ascend/matmul/kernel.h new file mode 100644 index 0000000..4070634 --- /dev/null +++ b/src/ascend/matmul/kernel.h @@ -0,0 +1,44 @@ +#ifndef INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/matmul.h" +#include "operator.h" + +namespace 
infini::ops { + +template <> +class Operator : public Matmul { + public: + Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : Matmul(a, b, c, trans_a, trans_b) {} + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + auto stream = static_cast(stream_); + auto t_a = ascend::buildAclTensor(a, trans_a); + auto t_b = ascend::buildAclTensor(b, trans_b); + auto t_out = ascend::buildAclTensor(c); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + // cube_math_type = 1: allow fp16 accumulation. + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_needed, + &executor); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnMatmul(arena.buf, ws_needed, executor, stream); + + aclDestroyTensor(t_a); + aclDestroyTensor(t_b); + aclDestroyTensor(t_out); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/reshape_and_cache/kernel.h b/src/ascend/reshape_and_cache/kernel.h new file mode 100644 index 0000000..609a1ee --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/device_.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator + : public ReshapeAndCache { + public: + using ReshapeAndCache::ReshapeAndCache; + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + // Copy slot_mapping to host for address computation. 
+ auto num_tokens = static_cast(num_tokens_); + std::vector slots(num_tokens); + aclrtMemcpyAsync(slots.data(), num_tokens * sizeof(int64_t), + slot_mapping.data(), num_tokens * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + auto bs = static_cast(block_size_); + auto row_bytes = static_cast(num_kv_heads_ * head_size_) * + kDataTypeToSize.at(key.dtype()); + + // kv_cache layout: [2, num_blocks, block_size, num_kv_heads, head_size] + // kv_cache[0] = key cache, kv_cache[1] = value cache. + // Stride for the first dim (K vs V): kv_cache.stride(0). + auto kv_stride0 = static_cast(kv_cache_out.stride(0)); + + for (int64_t i = 0; i < num_tokens; ++i) { + auto slot = slots[i]; + if (slot < 0) continue; // Padding token — skip. + auto block_idx = slot / bs; + auto offset = slot % bs; + + auto cache_offset = (block_idx * kv_cache_out.stride(1) + + offset * kv_cache_out.stride(2)) * + kv_cache_out.element_size(); + + auto* k_src = static_cast(key.data()) + + i * key.stride(0) * key.element_size(); + auto* k_dst = static_cast(kv_cache_out.data()) + cache_offset; + aclrtMemcpyAsync(k_dst, row_bytes, k_src, row_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + + auto* v_src = static_cast(value.data()) + + i * value.stride(0) * value.element_size(); + auto* v_dst = static_cast(kv_cache_out.data()) + + kv_stride0 * kv_cache_out.element_size() + cache_offset; + aclrtMemcpyAsync(v_dst, row_bytes, v_src, row_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/rms_norm/kernel.h b/src/ascend/rms_norm/kernel.h new file mode 100644 index 0000000..9eef1bb --- /dev/null +++ b/src/ascend/rms_norm/kernel.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_rms_norm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" 
+#include "base/rms_norm.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public RmsNorm { + public: + Operator(const Tensor input, const Tensor weight, float eps, Tensor out) + : RmsNorm(input, weight, eps, out) { + // aclnnRmsNorm writes rstd as a required side output. + // Allocate a persistent device buffer for it. + rstd_shape_ = {static_cast(batch_size_), + static_cast(nhead_)}; + size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + + ~Operator() { + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + auto t_in = ascend::buildAclTensor(input); + auto t_weight = ascend::buildAclTensor(weight); + auto t_out = ascend::buildAclTensor(out); + // rstd is always float32 regardless of input dtype. + auto t_rstd = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_data_); + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, t_rstd, &ws_needed, + &executor); + auto stream = static_cast(stream_); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnRmsNorm(arena.buf, ws_needed, executor, stream); + aclDestroyTensor(t_in); + aclDestroyTensor(t_weight); + aclDestroyTensor(t_out); + aclDestroyTensor(t_rstd); + } + + private: + std::vector rstd_shape_; + void* rstd_data_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/rotary_embedding/kernel.h b/src/ascend/rotary_embedding/kernel.h new file mode 100644 index 0000000..5c3da01 --- /dev/null +++ b/src/ascend/rotary_embedding/kernel.h @@ -0,0 +1,505 @@ +#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_ +#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_ + +#include +#include +#include +#include + +#include "acl/acl.h" +#include 
"aclnn/aclnn_base.h" +#include "aclnnop/aclnn_apply_rotary_pos_emb_v2.h" +#include "aclnnop/aclnn_index_select.h" +#include "aclnnop/aclnn_rotary_position_embedding.h" +#include "ascend/data_type_.h" +#include "ascend/workspace_pool_.h" +#include "base/rotary_embedding.h" +#include "operator.h" + +namespace infini::ops { + +// aclnnApplyRotaryPosEmbV2 hardware constraints on Atlas A2/A3: +// - rotaryMode "half" only (neox style) +// - D (last dim of queryRef) must be 64 or 128 +// - bfloat16 only (float16 accumulates with ~1 ULP error that exceeds +// atol=0.001 in tests; bfloat16 passes with atol=0.005) +// +// Use V2 when all three hold; fall back to V1 otherwise. +static bool use_rope_v2(int64_t D, bool is_neox, DataType dtype) { + return is_neox && (D == 64 || D == 128) && dtype == DataType::kBFloat16; +} + +template <> +class Operator + : public RotaryEmbedding { + public: + Operator(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, Tensor key_out) + : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size, + rotary_dim, is_neox_style, query_out, key_out) { + const int64_t max_seq_len = cos_sin_cache.size(0); + const int64_t R = rotary_dim_; + const int64_t half_R = R / 2; + cache_elem_size_ = cos_sin_cache.element_size(); + + // Copy raw cache to host for pre-expansion (one-time cost). + size_t raw_bytes = static_cast(max_seq_len * R) * cache_elem_size_; + std::vector cache_host(raw_bytes); + aclrtMemcpy(cache_host.data(), raw_bytes, cos_sin_cache.data(), raw_bytes, + ACL_MEMCPY_DEVICE_TO_HOST); + + // Pre-expand into separate cos/sin tables with duplicated values. + // After expansion each row is R-wide: + // neox: cos = [c0..c_{hR-1}, c0..c_{hR-1}] (first half repeated) + // interleave: cos = [c0,c0, c1,c1, ..., c_{hR-1},c_{hR-1}] + // Same pattern for sin. 
+ table_bytes_ = raw_bytes; + std::vector cos_table_host(table_bytes_); + std::vector sin_table_host(table_bytes_); + + for (int64_t p = 0; p < max_seq_len; ++p) { + if (is_neox_style_) { + for (int64_t j = 0; j < half_R; ++j) { + const uint8_t* c_src = + cache_host.data() + + static_cast(p * R + j) * cache_elem_size_; + const uint8_t* s_src = + cache_host.data() + + static_cast(p * R + half_R + j) * cache_elem_size_; + auto* cos_dst = cos_table_host.data(); + auto* sin_dst = sin_table_host.data(); + std::memcpy( + cos_dst + static_cast(p * R + j) * cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy(cos_dst + static_cast(p * R + half_R + j) * + cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy( + sin_dst + static_cast(p * R + j) * cache_elem_size_, + s_src, cache_elem_size_); + std::memcpy(sin_dst + static_cast(p * R + half_R + j) * + cache_elem_size_, + s_src, cache_elem_size_); + } + } else { + for (int64_t j = 0; j < half_R; ++j) { + const uint8_t* c_src = + cache_host.data() + + static_cast(p * R + j) * cache_elem_size_; + const uint8_t* s_src = + cache_host.data() + + static_cast(p * R + half_R + j) * cache_elem_size_; + auto* cos_dst = cos_table_host.data(); + auto* sin_dst = sin_table_host.data(); + std::memcpy( + cos_dst + static_cast(p * R + 2 * j) * cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy(cos_dst + static_cast(p * R + 2 * j + 1) * + cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy( + sin_dst + static_cast(p * R + 2 * j) * cache_elem_size_, + s_src, cache_elem_size_); + std::memcpy(sin_dst + static_cast(p * R + 2 * j + 1) * + cache_elem_size_, + s_src, cache_elem_size_); + } + } + } + + // Upload expanded tables to device (one-time). 
+ aclrtMalloc(&cos_table_dev_, table_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_table_dev_, table_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(cos_table_dev_, table_bytes_, cos_table_host.data(), + table_bytes_, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(sin_table_dev_, table_bytes_, sin_table_host.data(), + table_bytes_, ACL_MEMCPY_HOST_TO_DEVICE); + + const int64_t T = num_tokens_; + const int64_t Nq = num_heads_; + const int64_t Nkv = num_kv_heads_; + const int64_t D = head_size_; + const bool v2 = use_rope_v2(R, is_neox_style_, query.dtype()); + use_v2_ = v2; + + // Gathered output buffers [T, R] — filled by aclnnIndexSelect at runtime. + gathered_cs_bytes_ = static_cast(T * R) * cache_elem_size_; + aclrtMalloc(&cos_dev_, gathered_cs_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_dev_, gathered_cs_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + + // Scratch for partial-rotation (R < D) — used by both V1 and V2. + if (R < D) { + size_t q_rot_bytes = static_cast(T * Nq * R) * cache_elem_size_; + size_t k_rot_bytes = static_cast(T * Nkv * R) * cache_elem_size_; + aclrtMalloc(&q_rot_dev_, q_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&k_rot_dev_, k_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + if (!v2) { + aclrtMalloc(&q_out_rot_dev_, q_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&k_out_rot_dev_, k_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + } + } + + ~Operator() { + if (cos_table_dev_) aclrtFree(cos_table_dev_); + if (sin_table_dev_) aclrtFree(sin_table_dev_); + if (cos_dev_) aclrtFree(cos_dev_); + if (sin_dev_) aclrtFree(sin_dev_); + if (q_rot_dev_) aclrtFree(q_rot_dev_); + if (k_rot_dev_) aclrtFree(k_rot_dev_); + if (q_out_rot_dev_) aclrtFree(q_out_rot_dev_); + if (k_out_rot_dev_) aclrtFree(k_out_rot_dev_); + } + + void operator()(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) const 
override { + auto stream = static_cast(stream_); + + const int64_t T = query.size(0); + const int64_t Nq = query.size(1); + const int64_t Nkv = key.size(1); + const int64_t D = head_size; + const int64_t R = rotary_dim; + const int64_t max_seq_len = cos_sin_cache.size(0); + + assert(R <= D); + assert(cos_sin_cache.size(1) == R); + + // 1. Gather cos/sin on device via aclnnIndexSelect — fully async. + // No host sync, no D2H copy. Positions stay on device. + { + aclDataType acl_dt_cs = ascend::toAclDtype(query.dtype()); + + // Table tensors: [max_seq_len, R] + std::vector table_shape = {max_seq_len, R}; + std::vector table_strides = {R, 1}; + std::vector table_storage = {max_seq_len * R}; + + aclTensor* t_cos_table = aclCreateTensor( + table_shape.data(), 2, acl_dt_cs, table_strides.data(), 0, + ACL_FORMAT_ND, table_storage.data(), 1, cos_table_dev_); + aclTensor* t_sin_table = aclCreateTensor( + table_shape.data(), 2, acl_dt_cs, table_strides.data(), 0, + ACL_FORMAT_ND, table_storage.data(), 1, sin_table_dev_); + + // Index tensor: positions [T], int64 — stays on device. + std::vector idx_shape = {T}; + std::vector idx_strides = {1}; + std::vector idx_storage = {T}; + aclTensor* t_idx = aclCreateTensor( + idx_shape.data(), 1, ACL_INT64, idx_strides.data(), 0, ACL_FORMAT_ND, + idx_storage.data(), 1, const_cast(positions.data())); + + // Output tensors: [T, R] + std::vector out_shape = {T, R}; + std::vector out_strides = {R, 1}; + std::vector out_storage = {T * R}; + + aclTensor* t_cos_out = + aclCreateTensor(out_shape.data(), 2, acl_dt_cs, out_strides.data(), 0, + ACL_FORMAT_ND, out_storage.data(), 1, cos_dev_); + aclTensor* t_sin_out = + aclCreateTensor(out_shape.data(), 2, acl_dt_cs, out_strides.data(), 0, + ACL_FORMAT_ND, out_storage.data(), 1, sin_dev_); + + // Get workspace sizes and executors for both gathers. 
+ uint64_t ws_cos = 0, ws_sin = 0; + aclOpExecutor *exec_cos = nullptr, *exec_sin = nullptr; + aclnnIndexSelectGetWorkspaceSize(t_cos_table, 0, t_idx, t_cos_out, + &ws_cos, &exec_cos); + aclnnIndexSelectGetWorkspaceSize(t_sin_table, 0, t_idx, t_sin_out, + &ws_sin, &exec_sin); + + // Single workspace buffer large enough for both calls. + uint64_t ws_max = ws_cos > ws_sin ? ws_cos : ws_sin; + auto& arena = ascend::workspacePool().ensure(stream, ws_max); + + aclnnIndexSelect(arena.buf, ws_cos, exec_cos, stream); + aclnnIndexSelect(arena.buf, ws_sin, exec_sin, stream); + + aclDestroyTensor(t_cos_table); + aclDestroyTensor(t_sin_table); + aclDestroyTensor(t_idx); + aclDestroyTensor(t_cos_out); + aclDestroyTensor(t_sin_out); + } + + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + + if (use_v2_) { + // V2: fused Q+K, in-place, layout=4 (T-first 3D), "half" mode. + // cos/sin shape: [T, 1, R]. + std::vector cs_shape = {T, 1, R}; + std::vector cs_strides = {R, R, 1}; + std::vector cs_storage = {T * R}; + aclTensor* t_cos = + aclCreateTensor(cs_shape.data(), 3, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, cos_dev_); + aclTensor* t_sin = + aclCreateTensor(cs_shape.data(), 3, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, sin_dev_); + + int64_t layout = 4; + if (R == D) { + apply_rope_v2_full(query, key, query_out, key_out, T, Nq, Nkv, D, + acl_dt, t_cos, t_sin, layout, stream); + } else { + apply_rope_v2_partial(query, key, query_out, key_out, T, Nq, Nkv, D, R, + acl_dt, t_cos, t_sin, layout, stream); + } + aclDestroyTensor(t_cos); + aclDestroyTensor(t_sin); + } else { + // V1: separate Q and K calls, non-in-place, [1,T,1,R] cos/sin. 
+ std::vector cs_shape = {1, T, 1, R}; + std::vector cs_strides = {T * R, R, R, 1}; + std::vector cs_storage = {T * R}; + aclTensor* t_cos = + aclCreateTensor(cs_shape.data(), 4, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, cos_dev_); + aclTensor* t_sin = + aclCreateTensor(cs_shape.data(), 4, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, sin_dev_); + + int64_t mode = is_neox_style ? 0 : 1; + apply_rope_v1(query, query_out, T, Nq, D, R, mode, t_cos, t_sin, + q_rot_dev_, q_out_rot_dev_, stream); + apply_rope_v1(key, key_out, T, Nkv, D, R, mode, t_cos, t_sin, k_rot_dev_, + k_out_rot_dev_, stream); + + aclDestroyTensor(t_cos); + aclDestroyTensor(t_sin); + } + } + + private: + size_t cache_elem_size_ = 1; + + // Pre-expanded cos/sin tables on device: [max_seq_len, R]. + // Built once in the constructor with neox/interleave duplication. + void* cos_table_dev_ = nullptr; + void* sin_table_dev_ = nullptr; + size_t table_bytes_ = 0; + + // true when V2 hardware constraints are met (neox, D∈{64,128}, bf16). + bool use_v2_ = false; + + // Device buffers for gathered [T, R] cos/sin (shared by V1 and V2). + void* cos_dev_ = nullptr; + void* sin_dev_ = nullptr; + size_t gathered_cs_bytes_ = 0; + + // Scratch for partial rotation (R < D). 
+ void* q_rot_dev_ = nullptr; + void* k_rot_dev_ = nullptr; + void* q_out_rot_dev_ = nullptr; + void* k_out_rot_dev_ = nullptr; + + // --- V2 helpers (neox bf16, D∈{64,128}) --- + + void apply_rope_v2_full(const Tensor& q, const Tensor& k, Tensor& q_out, + Tensor& k_out, int64_t T, int64_t Nq, int64_t Nkv, + int64_t D, aclDataType acl_dt, aclTensor* t_cos, + aclTensor* t_sin, int64_t layout, + aclrtStream stream) const { + size_t elem_sz = q.element_size(); + if (q.data() != q_out.data()) { + aclrtMemcpyAsync(const_cast(q_out.data()), + static_cast(T * Nq * D) * elem_sz, q.data(), + static_cast(T * Nq * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + if (k.data() != k_out.data()) { + size_t k_elem_sz = k.element_size(); + aclrtMemcpyAsync(const_cast(k_out.data()), + static_cast(T * Nkv * D) * k_elem_sz, k.data(), + static_cast(T * Nkv * D) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + std::vector q_shape = {T, Nq, D}; + std::vector q_strides = {Nq * D, D, 1}; + std::vector q_storage = {T * Nq * D}; + std::vector k_shape = {T, Nkv, D}; + std::vector k_strides = {Nkv * D, D, 1}; + std::vector k_storage = {T * Nkv * D}; + aclTensor* t_q = aclCreateTensor( + q_shape.data(), 3, acl_dt, q_strides.data(), 0, ACL_FORMAT_ND, + q_storage.data(), 1, const_cast(q_out.data())); + aclTensor* t_k = aclCreateTensor( + k_shape.data(), 3, acl_dt, k_strides.data(), 0, ACL_FORMAT_ND, + k_storage.data(), 1, const_cast(k_out.data())); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnApplyRotaryPosEmbV2GetWorkspaceSize( + t_q, t_k, t_cos, t_sin, layout, const_cast("half"), &ws, &exec); + auto& arena = ascend::workspacePool().ensure(stream, ws); + aclnnApplyRotaryPosEmbV2(arena.buf, ws, exec, stream); + aclDestroyTensor(t_q); + aclDestroyTensor(t_k); + } + + void apply_rope_v2_partial(const Tensor& q, const Tensor& k, Tensor& q_out, + Tensor& k_out, int64_t T, int64_t Nq, int64_t Nkv, + int64_t D, int64_t R, aclDataType acl_dt, + aclTensor* t_cos, 
aclTensor* t_sin, int64_t layout, + aclrtStream stream) const { + size_t elem_sz = q.element_size(); + size_t k_elem_sz = k.element_size(); + const int64_t pass = D - R; + + for (int64_t i = 0; i < T * Nq; ++i) { + aclrtMemcpyAsync(static_cast(q_rot_dev_) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + static_cast(q.data()) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + for (int64_t i = 0; i < T * Nkv; ++i) { + aclrtMemcpyAsync(static_cast(k_rot_dev_) + + static_cast(i * R) * k_elem_sz, + static_cast(R) * k_elem_sz, + static_cast(k.data()) + + static_cast(i * D) * k_elem_sz, + static_cast(R) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + std::vector qr_shape = {T, Nq, R}; + std::vector qr_strides = {Nq * R, R, 1}; + std::vector qr_storage = {T * Nq * R}; + std::vector kr_shape = {T, Nkv, R}; + std::vector kr_strides = {Nkv * R, R, 1}; + std::vector kr_storage = {T * Nkv * R}; + aclTensor* t_q_rot = + aclCreateTensor(qr_shape.data(), 3, acl_dt, qr_strides.data(), 0, + ACL_FORMAT_ND, qr_storage.data(), 1, q_rot_dev_); + aclTensor* t_k_rot = + aclCreateTensor(kr_shape.data(), 3, acl_dt, kr_strides.data(), 0, + ACL_FORMAT_ND, kr_storage.data(), 1, k_rot_dev_); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q_rot, t_k_rot, t_cos, t_sin, + layout, const_cast("half"), + &ws, &exec); + auto& arena = ascend::workspacePool().ensure(stream, ws); + aclnnApplyRotaryPosEmbV2(arena.buf, ws, exec, stream); + aclDestroyTensor(t_q_rot); + aclDestroyTensor(t_k_rot); + + for (int64_t i = 0; i < T * Nq; ++i) { + aclrtMemcpyAsync(static_cast(const_cast(q_out.data())) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + static_cast(q_rot_dev_) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + aclrtMemcpyAsync(static_cast(const_cast(q_out.data())) + + static_cast(i * D + R) * elem_sz, + 
static_cast(pass) * elem_sz, + static_cast(q.data()) + + static_cast(i * D + R) * elem_sz, + static_cast(pass) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + for (int64_t i = 0; i < T * Nkv; ++i) { + aclrtMemcpyAsync(static_cast(const_cast(k_out.data())) + + static_cast(i * D) * k_elem_sz, + static_cast(R) * k_elem_sz, + static_cast(k_rot_dev_) + + static_cast(i * R) * k_elem_sz, + static_cast(R) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + aclrtMemcpyAsync(static_cast(const_cast(k_out.data())) + + static_cast(i * D + R) * k_elem_sz, + static_cast(pass) * k_elem_sz, + static_cast(k.data()) + + static_cast(i * D + R) * k_elem_sz, + static_cast(pass) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + } + + // --- V1 helper (fallback for non-neox, fp16, or D not in {64,128}) --- + + void apply_rope_v1(const Tensor& x, Tensor& out, int64_t T, int64_t N, + int64_t D, int64_t R, int64_t mode, aclTensor* t_cos, + aclTensor* t_sin, void* x_rot_dev, void* out_rot_dev, + aclrtStream stream) const { + aclDataType acl_dt = ascend::toAclDtype(x.dtype()); + size_t elem_sz = x.element_size(); + + if (R < D) { + for (int64_t i = 0; i < T * N; ++i) { + aclrtMemcpyAsync(static_cast(x_rot_dev) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + static_cast(x.data()) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + std::vector rot_sh = {1, T, N, R}; + std::vector rot_st = {T * N * R, N * R, R, 1}; + std::vector rot_storage = {T * N * R}; + aclTensor* t_x_rot = + aclCreateTensor(rot_sh.data(), 4, acl_dt, rot_st.data(), 0, + ACL_FORMAT_ND, rot_storage.data(), 1, x_rot_dev); + aclTensor* t_out_rot = + aclCreateTensor(rot_sh.data(), 4, acl_dt, rot_st.data(), 0, + ACL_FORMAT_ND, rot_storage.data(), 1, out_rot_dev); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnRotaryPositionEmbeddingGetWorkspaceSize(t_x_rot, t_cos, t_sin, mode, + t_out_rot, &ws, &exec); + auto& arena = 
ascend::workspacePool().ensure(stream, ws); + aclnnRotaryPositionEmbedding(arena.buf, ws, exec, stream); + + const int64_t pass = D - R; + for (int64_t i = 0; i < T * N; ++i) { + aclrtMemcpyAsync(static_cast(const_cast(out.data())) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + static_cast(out_rot_dev) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + aclrtMemcpyAsync(static_cast(const_cast(out.data())) + + static_cast(i * D + R) * elem_sz, + static_cast(pass) * elem_sz, + static_cast(x.data()) + + static_cast(i * D + R) * elem_sz, + static_cast(pass) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + aclDestroyTensor(t_x_rot); + aclDestroyTensor(t_out_rot); + } else { + std::vector full_sh = {1, T, N, D}; + std::vector full_st = {T * N * D, N * D, D, 1}; + std::vector full_storage = {T * N * D}; + aclTensor* t_x = aclCreateTensor( + full_sh.data(), 4, acl_dt, full_st.data(), 0, ACL_FORMAT_ND, + full_storage.data(), 1, const_cast(x.data())); + aclTensor* t_out = aclCreateTensor( + full_sh.data(), 4, acl_dt, full_st.data(), 0, ACL_FORMAT_ND, + full_storage.data(), 1, const_cast(out.data())); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnRotaryPositionEmbeddingGetWorkspaceSize(t_x, t_cos, t_sin, mode, + t_out, &ws, &exec); + auto& arena = ascend::workspacePool().ensure(stream, ws); + aclnnRotaryPositionEmbedding(arena.buf, ws, exec, stream); + aclDestroyTensor(t_x); + aclDestroyTensor(t_out); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/swiglu/kernel.h b/src/ascend/swiglu/kernel.h new file mode 100644 index 0000000..c7d31e7 --- /dev/null +++ b/src/ascend/swiglu/kernel.h @@ -0,0 +1,70 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ +#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "aclnn_silu.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include 
"base/swiglu.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements SwiGLU as two ACLNN calls: silu(gate) into a temp buffer, +// then elementwise mul(input, temp) into out. +// aclnnSiluMul was not used because it fuses silu_AND_mul on the same +// tensor (x * silu(x)), whereas SwiGLU requires input * silu(gate) — +// two distinct inputs. +template <> +class Operator : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu(input, gate, out) { + size_t nbytes = input.numel() * kDataTypeToSize.at(input.dtype()); + aclrtMalloc(&temp_buf_, nbytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + + ~Operator() { aclrtFree(temp_buf_); } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + // temp_buf_ is a contiguous scratch buffer; give it contiguous strides. + Tensor temp_t{temp_buf_, gate.shape(), gate.dtype(), gate.device()}; + + auto t_in = ascend::buildAclTensor(input); + auto t_gate = ascend::buildAclTensor(gate); + auto t_out = ascend::buildAclTensor(out); + auto t_temp = ascend::buildAclTensor(temp_t); + + uint64_t ws_needed = 0; + aclOpExecutor* exec = nullptr; + auto stream = static_cast(stream_); + + // Step 1: silu(gate) -> temp. SwiGLU = input * silu(gate). + aclnnSiluGetWorkspaceSize(t_gate, t_temp, &ws_needed, &exec); + auto& silu_arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnSilu(silu_arena.buf, ws_needed, exec, stream); + + // Step 2: mul(input, temp) -> out. 
+ uint64_t mul_ws = 0; + exec = nullptr; + aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws, &exec); + auto& mul_arena = ascend::workspacePool().ensure(stream, mul_ws); + aclnnMul(mul_arena.buf, mul_ws, exec, stream); + + aclDestroyTensor(t_in); + aclDestroyTensor(t_gate); + aclDestroyTensor(t_out); + aclDestroyTensor(t_temp); + } + + private: + void* temp_buf_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/rotary_embedding.h b/src/base/rotary_embedding.h index a38b20e..70989fa 100644 --- a/src/base/rotary_embedding.h +++ b/src/base/rotary_embedding.h @@ -15,8 +15,8 @@ class RotaryEmbedding : public Operator { int64_t rotary_dim, bool is_neox_style, Tensor query_out, Tensor key_out) : num_tokens_{query.size(0)}, - num_heads_{query.size(1)}, - num_kv_heads_{key.size(1)}, + num_heads_{static_cast(query.size(1))}, + num_kv_heads_{static_cast(key.size(1))}, head_size_{head_size}, rotary_dim_{rotary_dim}, is_neox_style_{is_neox_style}, From 2ccd53fb2b963648338dd09b05b5143793dd2fd1 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 16:48:25 +0800 Subject: [PATCH 2/8] test(ascend): add NPU stream injection and new operator tests Pass stream to all CANN ops in existing tests; add FlashAttention, ReshapeAndCache, RotaryEmbedding, and E2E LLaMA layer tests. 
--- tests/test_add.py | 13 +- tests/test_causal_softmax.py | 9 +- tests/test_e2e_layer.py | 418 ++++++++++++++++++++++++++++++ tests/test_flash_attention.py | 442 ++++++++++++++++++++++++++++++++ tests/test_gemm.py | 8 +- tests/test_reshape_and_cache.py | 152 +++++++++++ tests/test_rms_norm.py | 7 +- tests/test_rotary_embedding.py | 266 +++++++++++++++++++ tests/test_swiglu.py | 7 +- 9 files changed, 1312 insertions(+), 10 deletions(-) create mode 100644 tests/test_e2e_layer.py create mode 100644 tests/test_flash_attention.py create mode 100644 tests/test_reshape_and_cache.py create mode 100644 tests/test_rotary_embedding.py diff --git a/tests/test_add.py b/tests/test_add.py index 8b8166c..f560435 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,7 +2,13 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randint_strided, randn_strided +from tests.utils import ( + Payload, + empty_strided, + get_npu_stream, + randint_strided, + randn_strided, +) _INT_DTYPES = (torch.int16, torch.int32, torch.int64) @@ -63,7 +69,10 @@ def test_add( def _add(input, other, out): - infini.ops.add(input, other, out) + if input.device.type == "npu": + infini.ops.add(input, other, out, stream=get_npu_stream(input)) + else: + infini.ops.add(input, other, out) return out diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 8b35457..df4894c 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -40,7 +40,10 @@ def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol, def _causal_softmax(input, out): - infini.ops.causal_softmax(input, out) + if input.device.type == "npu": + infini.ops.causal_softmax(input, out, stream=get_npu_stream(input)) + else: + 
infini.ops.causal_softmax(input, out) return out @@ -48,7 +51,7 @@ def _causal_softmax(input, out): def _torch_causal_softmax(input, out): mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) - result = torch.nn.functional.softmax(masked, dim=-1, dtype=input.dtype) + result = torch.nn.functional.softmax(masked, dim=-1) out.copy_(result) return out diff --git a/tests/test_e2e_layer.py b/tests/test_e2e_layer.py new file mode 100644 index 0000000..92df9a2 --- /dev/null +++ b/tests/test_e2e_layer.py @@ -0,0 +1,418 @@ +import infini.ops +import pytest +import torch + +from tests.utils import get_npu_stream, randn_strided, randint_strided + + +def _stream_kw(tensor): + if tensor.device.type == "npu": + return {"stream": get_npu_stream(tensor)} + + return {} + + +def _ref_rms_norm(x, weight, eps): + rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps) + + return (x / rms) * weight + + +def _ref_rope( + positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style +): + T = query.size(0) + R = rotary_dim + half_R = R // 2 + cos_half = cos_sin_cache[:, :half_R] + sin_half = cos_sin_cache[:, half_R:] + + def apply_rope(x): + out = x.clone() + + for t in range(T): + p = positions[t].item() + c = cos_half[p] + s = sin_half[p] + + if is_neox_style: + x1 = x[t, :, :half_R] + x2 = x[t, :, half_R:R] + out[t, :, :half_R] = c * x1 - s * x2 + out[t, :, half_R:R] = c * x2 + s * x1 + else: + x1 = x[t, :, 0::2] + x2 = x[t, :, 1::2] + out[t, :, 0::2] = c * x1 - s * x2 + out[t, :, 1::2] = c * x2 + s * x1 + + return out + + return apply_rope(query), apply_rope(key) + + +def _ref_sdpa(query, key, value, num_heads, num_kv_heads, head_size, scale, causal): + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = 
v.repeat_interleave(ratio, dim=0) + + out = torch.nn.functional.scaled_dot_product_attention( + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + scale=scale, + is_causal=causal, + ) + + return out.squeeze(0).transpose(0, 1) + + +def _infiniops_layer( + hidden, + positions, + cos_sin_cache, + input_norm_w, + qkv_proj_w, + o_proj_w, + gate_proj_w, + up_proj_w, + down_proj_w, + post_norm_w, + num_heads, + num_kv_heads, + head_size, + rotary_dim, + intermediate_size, + is_neox_style, + eps, + scale, + num_tokens, +): + """Run one LLaMA decoder layer using InfiniOps kernels.""" + kw = _stream_kw(hidden) + dtype = hidden.dtype + device = hidden.device + hidden_size = hidden.size(-1) + + # Save residual. + residual = hidden.clone() + + # 1. Input RMSNorm. + normed = torch.empty_like(hidden) + infini.ops.rms_norm(hidden, input_norm_w, eps, normed, **kw) + + # 2. QKV projection: [T, D] @ [D, (N+2*Nkv)*H] -> [T, (N+2*Nkv)*H]. + qkv_dim = (num_heads + 2 * num_kv_heads) * head_size + qkv = torch.empty(num_tokens, qkv_dim, dtype=dtype, device=device) + infini.ops.gemm(normed, qkv_proj_w, 1.0, 0.0, False, False, qkv, **kw) + + # Split Q, K, V. + q = ( + qkv[:, : num_heads * head_size] + .reshape( + num_tokens, + num_heads, + head_size, + ) + .contiguous() + ) + k = ( + qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + v = ( + qkv[:, (num_heads + num_kv_heads) * head_size :] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + + # 3. RoPE. + q_rot = torch.empty_like(q) + k_rot = torch.empty_like(k) + infini.ops.rotary_embedding( + positions, + q, + k, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + q_rot, + k_rot, + **kw, + ) + + # 4. Flash attention (single-sequence prefill, causal). 
+ attn_out = torch.empty( + num_tokens, + num_heads, + head_size, + dtype=dtype, + device=device, + ) + infini.ops.flash_attention( + q_rot, + k_rot, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + attn_out, + **kw, + ) + + # 5. O projection: [T, N*H] @ [N*H, D] -> [T, D]. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(attn_2d, o_proj_w, 1.0, 0.0, False, False, o_out, **kw) + + # 6. Residual add. + after_attn = torch.empty_like(residual) + infini.ops.add(residual, o_out, after_attn, **kw) + + # 7. Post-attention RMSNorm. + residual2 = after_attn.clone() + normed2 = torch.empty_like(after_attn) + infini.ops.rms_norm(after_attn, post_norm_w, eps, normed2, **kw) + + # 8. Gate + up projections. + gate = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + up = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.gemm(normed2, gate_proj_w, 1.0, 0.0, False, False, gate, **kw) + infini.ops.gemm(normed2, up_proj_w, 1.0, 0.0, False, False, up, **kw) + + # 9. SwiGLU: ``up * silu(gate)``. + ffn = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.swiglu(up, gate, ffn, **kw) + + # 10. Down projection: [T, FFN] @ [FFN, D] -> [T, D]. + down = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(ffn, down_proj_w, 1.0, 0.0, False, False, down, **kw) + + # 11. Second residual add. 
+ output = torch.empty_like(residual2) + infini.ops.add(residual2, down, output, **kw) + + return output + + +def _reference_layer( + hidden, + positions, + cos_sin_cache, + input_norm_w, + qkv_proj_w, + o_proj_w, + gate_proj_w, + up_proj_w, + down_proj_w, + post_norm_w, + num_heads, + num_kv_heads, + head_size, + rotary_dim, + intermediate_size, + is_neox_style, + eps, + scale, + num_tokens, +): + """PyTorch float32 reference for one LLaMA decoder layer.""" + # Compute in float32 on CPU for accuracy. + h = hidden.float().cpu() + pos = positions.cpu() + csc = cos_sin_cache.float().cpu() + inw = input_norm_w.float().cpu() + qkvw = qkv_proj_w.float().cpu() + ow = o_proj_w.float().cpu() + gw = gate_proj_w.float().cpu() + uw = up_proj_w.float().cpu() + dw = down_proj_w.float().cpu() + pnw = post_norm_w.float().cpu() + + # 1. Input RMSNorm. + residual = h.clone() + normed = _ref_rms_norm(h, inw, eps) + + # 2. QKV projection. + qkv = normed @ qkvw + + q = qkv[:, : num_heads * head_size].reshape(num_tokens, num_heads, head_size) + k = qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + v = qkv[:, (num_heads + num_kv_heads) * head_size :].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + + # 3. RoPE. + q_rot, k_rot = _ref_rope( + pos, + q, + k, + csc, + head_size, + rotary_dim, + is_neox_style, + ) + + # 4. SDPA. + attn_out = _ref_sdpa( + q_rot, k_rot, v, num_heads, num_kv_heads, head_size, scale, causal=True + ) + + # 5. O projection. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = attn_2d @ ow + + # 6. Residual add. + after_attn = residual + o_out + + # 7. Post-attention RMSNorm. + residual2 = after_attn.clone() + normed2 = _ref_rms_norm(after_attn, pnw, eps) + + # 8. Gate + up projections. + gate = normed2 @ gw + up = normed2 @ uw + + # 9. SwiGLU: ``up * silu(gate)``. + ffn = up * (gate * torch.sigmoid(gate)) + + # 10. Down projection. 
+ down = ffn @ dw + + # 11. Second residual add. + output = residual2 + down + + return output.to(hidden.dtype).to(hidden.device) + + +def _make_rope_cache(max_seq_len, rotary_dim, dtype, device): + """Build a proper RoPE cos/sin cache (bounded to [-1, 1]).""" + freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + t = torch.arange(max_seq_len, dtype=torch.float32) + angles = torch.outer(t, freq) # [max_seq_len, half_dim] + cos_half = torch.cos(angles).to(dtype=dtype, device=device) + sin_half = torch.sin(angles).to(dtype=dtype, device=device) + + return torch.cat([cos_half, sin_half], dim=-1) + + +@pytest.mark.parametrize("device", ("npu",)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 5e-3, 5e-3), + (torch.bfloat16, 1e-2, 2e-2), + ), +) +def test_llama_layer(device, dtype, rtol, atol): + """End-to-end test of a LLaMA decoder layer using InfiniOps kernels.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + # Small LLaMA-like model config. + hidden_size = 512 + num_heads = 8 + num_kv_heads = 2 + head_size = hidden_size // num_heads + intermediate_size = 1024 + num_tokens = 1 + max_seq_len = 16 + rotary_dim = head_size + is_neox_style = True + eps = 1e-6 + scale = 1.0 / head_size**0.5 + + def _scaled_weight(*shape): + return randn_strided(shape, None, dtype=dtype, device=device) / shape[0] ** 0.5 + + # Random weights (stored as [in_features, out_features], Xavier-scaled). 
+ qkv_proj_w = _scaled_weight( + hidden_size, + (num_heads + 2 * num_kv_heads) * head_size, + ) + o_proj_w = _scaled_weight(num_heads * head_size, hidden_size) + gate_proj_w = _scaled_weight(hidden_size, intermediate_size) + up_proj_w = _scaled_weight(hidden_size, intermediate_size) + down_proj_w = _scaled_weight(intermediate_size, hidden_size) + input_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + post_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + + # Proper cos/sin cache from frequency decomposition (bounded [-1, 1]). + cos_sin_cache = _make_rope_cache(max_seq_len, rotary_dim, dtype, device) + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + + # Input hidden states scaled to prevent value explosion through layers. + hidden = ( + randn_strided( + (num_tokens, hidden_size), + None, + dtype=dtype, + device=device, + ) + / hidden_size**0.5 + ) + + common = dict( + positions=positions, + cos_sin_cache=cos_sin_cache, + input_norm_w=input_norm_w, + qkv_proj_w=qkv_proj_w, + o_proj_w=o_proj_w, + gate_proj_w=gate_proj_w, + up_proj_w=up_proj_w, + down_proj_w=down_proj_w, + post_norm_w=post_norm_w, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + rotary_dim=rotary_dim, + intermediate_size=intermediate_size, + is_neox_style=is_neox_style, + eps=eps, + scale=scale, + num_tokens=num_tokens, + ) + + infini_out = _infiniops_layer(hidden, **common) + ref_out = _reference_layer(hidden, **common) + + max_diff = (infini_out.float() - ref_out.float()).abs().max().item() + assert torch.allclose(infini_out, ref_out, rtol=rtol, atol=atol), ( + f"Max diff: {max_diff}" + ) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py new file mode 100644 index 0000000..4b8be3f --- /dev/null +++ b/tests/test_flash_attention.py @@ -0,0 +1,442 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, 
randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ( + (32, 32, 128), # MHA + (32, 8, 128), # GQA (4x) + (16, 4, 64), # GQA (4x), smaller + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_single( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Single sequence prefill (no block table).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_tokens = 16 + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention( + q, + k, + v, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ((32, 8, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_multi( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Multi-sequence prefill with cu_seqlens.""" + if device == 
@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
    "num_heads, num_kv_heads, head_size, block_size",
    (
        (32, 8, 128, 128),
        (16, 4, 64, 128),
    ),
)
@pytest.mark.parametrize(
    ("dtype", "rtol", "atol"),
    (
        (torch.float16, 1e-3, 1e-3),
        (torch.bfloat16, 1e-2, 5e-3),
    ),
)
@pytest.mark.parametrize("device", ("npu",))
def test_flash_attention_decode(
    num_heads,
    num_kv_heads,
    head_size,
    block_size,
    dtype,
    rtol,
    atol,
    device,
):
    """Decode phase: single token per request with paged KV cache."""
    # Skip when no Ascend NPU runtime is present.
    if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()):
        pytest.skip("NPU not available")

    num_reqs = 3
    kv_len = 16  # Total KV length per request.
    num_blocks_per_req = (kv_len + block_size - 1) // block_size  # ceil-div
    num_blocks = num_reqs * num_blocks_per_req
    scale = 1.0 / head_size**0.5

    # One query token per request: [num_reqs, num_heads, head_size].
    query = randn_strided(
        (num_reqs, num_heads, head_size), None, dtype=dtype, device=device
    )
    # Paged KV cache: vLLM standard layout [num_blocks, block_size, KV_N, D].
    kv_cache = randn_strided(
        (num_blocks, block_size, num_kv_heads, head_size),
        None,
        dtype=dtype,
        device=device,
    )
    output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device)

    # Block table: request i uses blocks [i*num_blocks_per_req, ...].
    block_table = torch.zeros(
        (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device
    )
    for i in range(num_reqs):
        for j in range(num_blocks_per_req):
            block_table[i, j] = i * num_blocks_per_req + j

    # One query token per request -> cu_seqlens_q is simply [0, 1, 2, ...].
    cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device)
    cu_seqlens_kv = torch.tensor(
        [i * kv_len for i in range(num_reqs + 1)], dtype=torch.int64, device=device
    )

    return Payload(
        # Act: paged path -- block_table plus a non-zero block_size.
        # Positional tail: causal=True, window_left=-1, window_right=0.
        lambda q, k, v, o: _flash_attention(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            block_table,
            num_heads,
            num_kv_heads,
            head_size,
            scale,
            True,
            -1,
            0,
            block_size,
            o,
        ),
        # Reference reads K and V from the same combined cache tensor, so the
        # `v` lambda argument is unused here.
        lambda q, k, v, o: _ref_flash_attention_paged(
            q,
            k,
            block_table,
            cu_seqlens_q,
            cu_seqlens_kv,
            num_heads,
            num_kv_heads,
            head_size,
            block_size,
            scale,
            causal=True,
        ),
        (query, kv_cache, kv_cache, output),
        {},
        rtol=rtol,
        atol=atol,
    )
else: + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + ) + + return output + + +def _ref_flash_attention( + query, key, value, num_heads, num_kv_heads, head_size, scale, causal=True +): + """PyTorch SDPA reference for single-sequence prefill.""" + # [T, N, D] -> [N, T, D] + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + # GQA: expand K/V to match num_heads. + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = v.repeat_interleave(ratio, dim=0) + + # [N, T, D] -> [1, N, T, D] for scaled_dot_product_attention. + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=scale, is_causal=causal + ) + + # [1, N, T, D] -> [T, N, D] -> original dtype. 
+ return out.squeeze(0).transpose(0, 1).to(query.dtype) + + +def _ref_flash_attention_multi( + query, + key, + value, + seq_lens_q, + seq_lens_kv, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, +): + """PyTorch SDPA reference for multi-sequence prefill.""" + outputs = [] + offset = 0 + for sq, sk in zip(seq_lens_q, seq_lens_kv): + q = query[offset : offset + sq] + k = key[offset : offset + sq] + v = value[offset : offset + sq] + out = _ref_flash_attention( + q, k, v, num_heads, num_kv_heads, head_size, scale, causal + ) + outputs.append(out) + offset += sq + + return torch.cat(outputs, dim=0) + + +def _ref_flash_attention_paged( + query, + kv_cache_arg, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, +): + """PyTorch SDPA reference for decode with paged KV cache.""" + cu_kv = cu_seqlens_kv.cpu() + bt = block_table.cpu() + cache = kv_cache_arg.cpu() + q_cpu = query.cpu() + num_reqs = bt.size(0) + outputs = [] + + for i in range(num_reqs): + q = q_cpu[i : i + 1] # [1, N, D] + kv_len = int(cu_kv[i + 1] - cu_kv[i]) + + # Gather KV from paged cache. + # cache: [num_blocks, KV_N, block_size, D] + blocks = bt[i] + k_pages = [] + v_pages = [] + remaining = kv_len + for b in blocks: + if remaining <= 0: + break + take = min(remaining, block_size) + # cache layout: [num_blocks, block_size, KV_N, D] + # Slice [take, KV_N, D], transpose to [KV_N, take, D] for cat. + k_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1)) + v_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1)) + remaining -= take + k = torch.cat(k_pages, dim=1) # [KV_N, kv_len, D] + v = torch.cat(v_pages, dim=1) + + # Decode: Q_S=1 attends to all past KV positions; causal masking is + # not applicable here (it would mask everything beyond position 0). 
+ out = _ref_flash_attention( + q, # [1, N, D] - already TND format + k.transpose(0, 1), # [KV_N, kv_len, D] -> [kv_len, KV_N, D] + v.transpose(0, 1), + num_heads, + num_kv_heads, + head_size, + scale, + causal=False, + ) + outputs.append(out) + + return torch.cat(outputs, dim=0).to(query.device) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index af8b44f..3f48562 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -86,7 +86,13 @@ def test_gemm( def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): if a.device.type == "npu": infini.ops.gemm( - a, b, alpha, beta, trans_a, trans_b, c, + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, stream=get_npu_stream(a), ) else: diff --git a/tests/test_reshape_and_cache.py b/tests/test_reshape_and_cache.py new file mode 100644 index 0000000..813afc3 --- /dev/null +++ b/tests/test_reshape_and_cache.py @@ -0,0 +1,152 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, randn_strided + +# ReshapeAndCache only works on NPU (aclrtMemcpy-based), so tests only +# parametrize on float16/bfloat16 and use explicit device parametrization. 
@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
    "num_tokens, num_kv_heads, head_size, num_blocks, block_size",
    (
        (1, 8, 128, 4, 16),
        (4, 8, 128, 4, 16),
        (8, 4, 64, 8, 32),
        (16, 2, 128, 8, 64),
    ),
)
@pytest.mark.parametrize(
    ("dtype", "rtol", "atol"),
    (
        (torch.float16, 1e-3, 1e-3),
        (torch.bfloat16, 1e-2, 5e-3),
    ),
)
@pytest.mark.parametrize("device", ("npu",))
def test_reshape_and_cache_contiguous(
    num_tokens,
    num_kv_heads,
    head_size,
    num_blocks,
    block_size,
    dtype,
    rtol,
    atol,
    device,
):
    """Scatter K/V tokens into the paged cache with a dense slot mapping.

    Token ``i`` lands in slot ``i`` (blocks fill front-to-back, no holes);
    the act kernel and the Python reference must produce identical caches.
    """
    # Skip when no Ascend NPU runtime is present.
    if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()):
        pytest.skip("NPU not available")

    key = randn_strided(
        (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device
    )
    value = randn_strided(
        (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device
    )
    # Layout: [2, num_blocks, block_size, num_kv_heads, head_size]
    # Index 0 = key cache, index 1 = value cache.
    kv_cache = torch.zeros(
        (2, num_blocks, block_size, num_kv_heads, head_size),
        dtype=dtype,
        device=device,
    )
    # Contiguous slot mapping: token i -> slot i.
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    return Payload(
        _reshape_and_cache,
        _ref_reshape_and_cache,
        # kv_cache doubles as the in-place output argument.
        (key, value, kv_cache, slot_mapping, kv_cache),
        {},
        rtol=rtol,
        atol=atol,
    )
+ slot_mapping = torch.tensor( + [i * 2 for i in range(num_tokens)], dtype=torch.int64, device=device + ) + + return Payload( + _reshape_and_cache, + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +def _reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + if key.device.type == "npu": + infini.ops.reshape_and_cache( + key, value, kv_cache, slot_mapping, kv_cache_out, stream=get_npu_stream(key) + ) + else: + infini.ops.reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out) + + return kv_cache_out + + +def _ref_reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + kv_cache_out = kv_cache_out.clone() + slots = slot_mapping.cpu() + block_size = kv_cache_out.size(2) + + for i in range(key.size(0)): + slot = int(slots[i].item()) + + if slot < 0: + continue + + block_idx = slot // block_size + offset = slot % block_size + kv_cache_out[0, block_idx, offset, :, :] = key[i, :, :] + kv_cache_out[1, block_idx, offset, :, :] = value[i, :, :] + + return kv_cache_out diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index d6d4dff..ba540a9 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -53,7 +53,10 @@ def test_rms_norm( def _rms_norm(input, weight, *, eps=1e-6, out=None): - infini.ops.rms_norm(input, weight, eps, out) + if input.device.type == "npu": + infini.ops.rms_norm(input, weight, eps, out, stream=get_npu_stream(input)) + else: + infini.ops.rms_norm(input, weight, eps, out) return out diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py new file mode 100644 index 0000000..733ae43 --- /dev/null +++ b/tests/test_rotary_embedding.py @@ -0,0 +1,266 @@ +import infini.ops +import pytest +import torch 
+ +from tests.utils import get_npu_stream, randn_strided, randint_strided + + +def _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, +): + if device == "npu": + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + stream=get_npu_stream(query), + ) + else: + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + ) + + return query_out, key_out + + +def _ref_rotary_embedding( + positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style +): + """PyTorch reference for RoPE. + + ``cos_sin_cache`` layout: ``[max_seq_len, rotary_dim]`` where the first + ``rotary_dim // 2`` columns are cos and the rest are sin. + """ + T = query.size(0) + R = rotary_dim + half_R = R // 2 + + cos_sin = cos_sin_cache.float() + cos_half = cos_sin[:, :half_R] + sin_half = cos_sin[:, half_R:] + + def apply_rope(x): + out = x.float().clone() + + for t in range(T): + p = positions[t].item() + c = cos_half[p] + s = sin_half[p] + + if is_neox_style: + x1 = x[t, :, :half_R].float() + x2 = x[t, :, half_R:R].float() + out[t, :, :half_R] = c * x1 - s * x2 + out[t, :, half_R:R] = c * x2 + s * x1 + else: + x1 = x[t, :, 0::2].float() + x2 = x[t, :, 1::2].float() + out[t, :, 0::2] = c * x1 - s * x2 + out[t, :, 1::2] = c * x2 + s * x1 + + return out.to(x.dtype) + + return apply_rope(query), apply_rope(key) + + +def _assert_close(actual, expected, rtol, atol): + assert torch.allclose(actual, expected, rtol=rtol, atol=atol), ( + f"Max diff: {(actual.float() - expected.float()).abs().max().item()}" + ) + + +@pytest.mark.parametrize( + "num_heads, head_size", + ( + (32, 128), + (8, 64), + ), +) +@pytest.mark.parametrize("is_neox_style", (True, False)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + 
@pytest.mark.parametrize(
    "num_heads, num_kv_heads, head_size, rotary_dim",
    (
        (32, 8, 128, 64),
        (16, 4, 64, 32),
    ),
)
@pytest.mark.parametrize("is_neox_style", (True,))
@pytest.mark.parametrize(
    ("dtype", "rtol", "atol"),
    (
        (torch.float16, 1e-3, 1e-3),
        (torch.bfloat16, 1e-2, 5e-3),
    ),
)
@pytest.mark.parametrize("device", ("npu",))
def test_rotary_embedding_partial(
    num_heads,
    num_kv_heads,
    head_size,
    rotary_dim,
    is_neox_style,
    dtype,
    rtol,
    atol,
    device,
):
    """Partial rotary: ``rotary_dim < head_size``.

    Only the first ``rotary_dim`` lanes of each head are rotated; the
    remaining lanes must pass through unchanged. GQA shapes are used
    (``num_kv_heads < num_heads``); only NeoX style is exercised here.
    """
    # Skip when no Ascend NPU runtime is present.
    if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()):
        pytest.skip("NPU not available")

    num_tokens = 16
    max_seq_len = 64

    # Random token positions indexing into the cos/sin cache.
    positions = randint_strided(
        0,
        max_seq_len,
        (num_tokens,),
        None,
        dtype=torch.int64,
        device=device,
    )
    query = randn_strided(
        (num_tokens, num_heads, head_size),
        None,
        dtype=dtype,
        device=device,
    )
    key = randn_strided(
        (num_tokens, num_kv_heads, head_size),
        None,
        dtype=dtype,
        device=device,
    )
    # Random cache values suffice for an act-vs-reference comparison; no
    # real frequency decomposition is needed here.
    cos_sin_cache = randn_strided(
        (max_seq_len, rotary_dim),
        None,
        dtype=dtype,
        device=device,
    )
    query_out = torch.empty_like(query)
    key_out = torch.empty_like(key)

    q_out, k_out = _rotary_embedding(
        positions,
        query,
        key,
        cos_sin_cache,
        head_size,
        rotary_dim,
        is_neox_style,
        query_out,
        key_out,
        device,
    )

    ref_q, ref_k = _ref_rotary_embedding(
        positions,
        query,
        key,
        cos_sin_cache,
        head_size,
        rotary_dim,
        is_neox_style,
    )

    _assert_close(q_out, ref_q, rtol, atol)
    _assert_close(k_out, ref_k, rtol, atol)
+++++++---------- .ci/run.py | 45 +++++----------------- .ci/tests/test_resource.py | 1 + .ci/tests/test_run.py | 33 +++++++++++++++++ 7 files changed, 159 insertions(+), 62 deletions(-) diff --git a/.ci/README.md b/.ci/README.md index 190d012..f44b5a3 100644 --- a/.ci/README.md +++ b/.ci/README.md @@ -158,7 +158,7 @@ Platform is auto-detected (via `nvidia-smi`/`ixsmi`/`mx-smi`/`mthreads-gmi`/`cnm | `--stage` | Run only the specified stage | | `--image-tag` | Override image tag | | `--gpu-id` | Override GPU device IDs (nvidia via `--gpus`, others via `CUDA_VISIBLE_DEVICES`) | -| `--test` | Override pytest test path (e.g., `tests/test_gemm.py::test_gemm`) | +| `--test` | Replace stage command entirely (e.g., `pytest tests/test_add.py -v`) | | `--results-dir` | Host directory mounted to `/workspace/results` inside the container | | `--local` | Mount current directory (read-only) instead of cloning from git | | `--dry-run` | Print docker command without executing | @@ -195,7 +195,7 @@ Proxy vars are forwarded from the host. Test results are written to `--results-d | MetaX | `--privileged` | `none` | `maca-pytorch:3.2.1.4-...` | `mx-smi` | | Moore | `--privileged` | `none` | `vllm_musa:20251112_hygon` | `mthreads-gmi` | | Cambricon | `--privileged` | `mlu` | `cambricon/pytorch:v1.25.3` | `cnmon` | -| Ascend | TODO | — | `ascend-pytorch:24.0.0` | — | +| Ascend | `--privileged` + device mounts | `npu` | `ascend-pytorch:24.0.RC3-A2-2.1.0` | `npu-smi` | `gpu_style` controls the Docker device injection mechanism: `nvidia` uses `--gpus`, `none` uses `CUDA_VISIBLE_DEVICES` (or skips injection for Moore), `mlu` uses `MLU_VISIBLE_DEVICES`. 
diff --git a/.ci/ci_resource.py b/.ci/ci_resource.py index 51b181f..de2953d 100644 --- a/.ci/ci_resource.py +++ b/.ci/ci_resource.py @@ -14,6 +14,7 @@ GPU_STYLE_NVIDIA = "nvidia" GPU_STYLE_NONE = "none" GPU_STYLE_MLU = "mlu" +GPU_STYLE_NPU = "npu" @dataclass @@ -44,6 +45,7 @@ class ResourcePool: "metax": "mx-smi", "moore": "mthreads-gmi", "cambricon": "cnmon", + "ascend": "npu-smi", } def __init__(self, platform, utilization_threshold=10): @@ -72,6 +74,9 @@ def detect_gpus(self) -> list[GpuInfo]: if self._platform == "cambricon": return self._detect_gpus_cambricon() + if self._platform == "ascend": + return self._detect_gpus_ascend() + tool = self.GPU_QUERY_TOOLS.get(self._platform) if not tool: @@ -325,6 +330,73 @@ def _detect_gpus_cambricon(self) -> list[GpuInfo]: return sorted(gpus, key=operator.attrgetter("index")) + def _detect_gpus_ascend(self) -> list[GpuInfo]: + """Parse npu-smi info output for Huawei Ascend NPUs. + + Output format (pipe-delimited table, two rows per NPU): + | 0 910B4 | OK | 86.5 41 ... + | 0 | 0000:C1:00.0 | 0 0 / 0 2789 / 32768 | + Row 1: index, name, health, power, temp, hugepages. + Row 2: chip_id, bus_id, aicore_util, memory_usage, hbm_usage. + """ + try: + result = subprocess.run( + ["npu-smi", "info"], + capture_output=True, + text=True, + timeout=10, + ) + except (FileNotFoundError, subprocess.TimeoutExpired): + return [] + + if result.returncode != 0: + return [] + + gpus = [] + lines = result.stdout.splitlines() + i = 0 + + while i < len(lines): + line = lines[i] + # Match row 1: "| {index} {name} ..." + m1 = re.match(r"^\|\s+(\d+)\s+", line) + + if m1 and i + 1 < len(lines): + try: + npu_index = int(m1.group(1)) + aicore_m = re.match( + r"^\|\s+\d+\s+\|\s+[\da-f:.]+\s+\|\s*([\d.]+)\s", lines[i + 1] + ) + + util_pct = float(aicore_m.group(1)) if aicore_m else 0.0 + + # Parse HBM usage from row 2: "{used} / {total}". 
+ hbm_m = re.search(r"([\d.]+)\s*/\s*([\d.]+)", lines[i + 1]) + + if hbm_m: + used_mb = float(hbm_m.group(1)) + total_mb = float(hbm_m.group(2)) + else: + used_mb, total_mb = 0.0, 0.0 + + gpus.append( + GpuInfo( + index=npu_index, + memory_used_mb=used_mb, + memory_total_mb=total_mb, + utilization_pct=util_pct, + ) + ) + except (ValueError, AttributeError): + pass + + i += 2 + continue + + i += 1 + + return sorted(gpus, key=operator.attrgetter("index")) + def detect_system_resources(self) -> SystemResources: """Read system memory from /proc/meminfo and CPU count.""" total_mb = 0.0 diff --git a/.ci/config.yaml b/.ci/config.yaml index b70e7df..a6a5e70 100644 --- a/.ci/config.yaml +++ b/.ci/config.yaml @@ -137,10 +137,34 @@ platforms: - name: test run: pytest tests/test_gemm.py -n 4 -v --tb=short --junitxml=/workspace/results/test-results.xml - ascend: # TODO: Ascend image is not ready yet + ascend: image: dockerfile: .ci/images/ascend/ build_args: - BASE_IMAGE: ascendhub.huawei.com/public-ascendhub/ascend-pytorch:24.0.0 - private_sdk: - source_env: PRIVATE_SDK_URL + BASE_IMAGE: quay.io/ascend/vllm-ascend:v0.18.0rc1-openeuler + PIP_INDEX_URL: https://pypi.org/simple + docker_args: + - "--runtime=runc" + - "--privileged" + - "--device=/dev/davinci0" + - "--device=/dev/davinci_manager" + - "--device=/dev/devmm_svm" + - "--device=/dev/hisi_hdc" + volumes: + - /usr/local/Ascend/driver:/usr/local/Ascend/driver:ro + - /usr/local/dcmi:/usr/local/dcmi:ro + - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro + env: + ASCEND_HOME_PATH: /usr/local/Ascend/ascend-toolkit/latest + setup: pip install .[dev] --no-build-isolation + jobs: + npu: + resources: + gpu_ids: "0" + gpu_style: npu + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/ -n 1 -k npu -v --tb=short --junitxml=/workspace/results/test-results.xml diff --git a/.ci/images/ascend/Dockerfile b/.ci/images/ascend/Dockerfile index 66392eb..5391d7d 100644 --- 
def apply_test_override(run_cmd, test_cmd):
    """Return ``test_cmd``, discarding the original stage command.

    The ``--test`` CLI flag is defined to replace a stage's command
    wholesale, whether or not the original command was a pytest
    invocation; ``run_cmd`` is accepted only so call sites read naturally.
    """
    return test_cmd
"pytest tests/test_add.py -v")', ) parser.add_argument( "--local", diff --git a/.ci/tests/test_resource.py b/.ci/tests/test_resource.py index cbe37d8..0db3fbb 100644 --- a/.ci/tests/test_resource.py +++ b/.ci/tests/test_resource.py @@ -93,6 +93,7 @@ def test_detect_system_resources(monkeypatch, tmp_path): "MemAvailable: 20000000 kB\n" ) + _real_open = open def fake_open(path, **kw): diff --git a/.ci/tests/test_run.py b/.ci/tests/test_run.py index 93987e5..65c6de6 100644 --- a/.ci/tests/test_run.py +++ b/.ci/tests/test_run.py @@ -296,3 +296,36 @@ def test_build_results_dir_under_base(): stages = [{"name": "test", "run": "pytest"}] d = run.build_results_dir("/tmp/my-results", "ascend", stages, "def5678") assert d.parent == Path("/tmp/my-results") + + +# --------------------------------------------------------------------------- +# Tests for `apply_test_override`. +# --------------------------------------------------------------------------- + + +def test_apply_test_override_replaces_pytest_command(): + assert run.apply_test_override("pytest tests/ -v", "pytest tests/test_add.py") == ( + "pytest tests/test_add.py" + ) + + +def test_apply_test_override_replaces_non_pytest_command(): + assert run.apply_test_override("ruff check .", "python docs/repro.py") == ( + "python docs/repro.py" + ) + + +def test_apply_test_override_replaces_empty_command(): + assert run.apply_test_override("", "bash script.sh") == "bash script.sh" + + +def test_apply_test_override_preserves_user_flags(): + cmd = "pytest tests/test_gemm.py -n 1 -v --tb=short" + assert run.apply_test_override("pytest tests/ -n 4", cmd) == cmd + + +def test_apply_test_override_with_shell_command(): + assert ( + run.apply_test_override("pytest tests/", "cd /tmp && python repro.py") + == "cd /tmp && python repro.py" + ) From 26c2bdc5837c98ef4a58b13a1f3ef336ddee60d9 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 16:48:38 +0800 Subject: [PATCH 4/8] docs: add Ascend FlashAttention design spec --- 
...026-03-30-ascend-flash-attention-design.md | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md diff --git a/docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md b/docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md new file mode 100644 index 0000000..c07012f --- /dev/null +++ b/docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md @@ -0,0 +1,225 @@ +# Ascend Flash Attention & Reshape-And-Cache Design + +**Date:** 2026-03-30 +**Status:** Approved +**Scope:** Two new operators for the Ascend backend, compatible with vLLM input layout conventions. + +## Overview + +Add `FlashAttention` and `ReshapeAndCache` operators to InfiniOps targeting the Ascend NPU backend. The operators wrap CANN's `aclnnFusedInferAttentionScore` (FIA) API and accept vLLM-compatible TND (token-major) tensor layouts, enabling direct integration with vLLM's attention pipeline. + +## Operator 1: FlashAttention + +### Interface + +```cpp +// src/base/flash_attention.h +class FlashAttention : public Operator { + public: + FlashAttention( + const Tensor query, // [num_tokens, num_heads, head_size] TND + const Tensor key, // TND or paged cache [num_blocks, KV_N, block_size, D] + const Tensor value, + std::optional block_table, // [num_reqs, max_blocks_per_req], INT32 + std::optional cu_seqlens_q, // [num_reqs + 1], INT64 + std::optional cu_seqlens_kv,// [num_reqs + 1], INT64 + int64_t num_heads, + int64_t num_kv_heads, + int64_t head_size, + double scale, // 1/sqrt(head_size) + int64_t sparse_mode, // 3 = causal (right-down triangular) + int64_t block_size, // 0 = no paging, else 128/256/384/512 + Tensor output // [num_tokens, num_heads, head_size] + ); + + virtual void operator()( + const Tensor query, const Tensor key, const Tensor value, + std::optional block_table, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + int64_t num_heads, int64_t 
num_kv_heads, int64_t head_size, + double scale, int64_t sparse_mode, int64_t block_size, + Tensor output + ) const = 0; +}; +``` + +### Tensor Layout + +All tensors use TND (token-major) layout to match vLLM conventions: + +| Tensor | Shape | Dtype | Notes | +|--------|-------|-------|-------| +| `query` | `[num_tokens, num_heads, head_size]` | fp16/bf16 | Concatenated query tokens | +| `key` | `[num_tokens, num_kv_heads, head_size]` or `[num_blocks, KV_N, block_size, D]` | fp16/bf16 | Input K or paged cache | +| `value` | Same shape as `key` | fp16/bf16 | Input V or paged cache | +| `output` | `[num_tokens, num_heads, head_size]` | fp16/bf16 | Attention output | +| `block_table` | `[num_reqs, max_blocks_per_req]` | INT32 | Paged KV cache block mapping | +| `cu_seqlens_q` | `[num_reqs + 1]` | INT64 | Cumulative query sequence lengths | +| `cu_seqlens_kv` | `[num_reqs + 1]` | INT64 | Cumulative KV sequence lengths | + +### ACLNN FIA Mapping + +The Ascend backend (`src/ascend/flash_attention/kernel.h`) wraps `aclnnFusedInferAttentionScore`: + +| InfiniOps | ACLNN FIA | Notes | +|-----------|-----------|-------| +| `query [T,N,D]` | `query` as `[1,N,T,D]` BNSD | Reshape (view, no copy) | +| `key` (paged cache) | `key` as `aclTensorList*` | Single-element list pointing to cache | +| `value` (paged cache) | `value` as `aclTensorList*` | Same as key | +| `block_table` | `blockTable` | Direct pass-through | +| `cu_seqlens_q` | `actualSeqLengths` | Extract to host `aclIntArray*` | +| `cu_seqlens_kv` | `actualSeqLengthsKv` | Extract to host `aclIntArray*` | +| `num_heads` | `numHeads` | | +| `num_kv_heads` | `numKeyValueHeads` | Supports GQA | +| `scale` | `scaleValue` | | +| `sparse_mode` | `sparseMode` | 3 = causal | +| `block_size` | `blockSize` | | +| `output [T,N,D]` | `attentionOut` as `[1,N,T,D]` | Reshape back | + +**Internal defaults (not exposed):** + +- `inputLayout` = `"BNSD"` +- `pseShift` = nullptr (no position encoding shift) +- `attenMask` = nullptr (causal 
handled by `sparseMode=3`) +- `preTokens` / `nextTokens` = `2147483647` (INT_MAX) +- `innerPrecise` = 0 (high precision mode) +- `softmaxLseFlag` = false +- All quantization parameters = nullptr + +### Workflow + +1. Reshape TND input tensors to BNSD views (no memory copy) +2. Extract `cu_seqlens_q`/`cu_seqlens_kv` to host-side `aclIntArray*` +3. Build ACL tensor descriptors via `ascend::buildAclTensor()` +4. Create `aclTensorList*` for key/value (single-element list wrapping the cache tensor) +5. Call `aclnnFusedInferAttentionScoreGetWorkspaceSize` +6. Allocate workspace via `WorkspacePool::ensure()` +7. Call `aclnnFusedInferAttentionScore` +8. Destroy all ACL descriptors + +### Constraints + +- **Dtypes:** float16, bfloat16 only +- **head_size:** must be 16-byte aligned (multiple of 8 for fp16, 4 for bf16), max 512 +- **num_heads:** max 256 +- **block_size:** 128, 256, 384, or 512 (multiple of 128). 0 disables paging +- **KV cache format:** `(num_blocks, KV_N, block_size, D)` preferred (better performance than `(num_blocks, block_size, H)`) +- **GQA:** `num_heads % num_kv_heads == 0`, ratio <= 64 +- **Paged attention requires:** `block_table` present, `cu_seqlens_kv` provided, `block_size >= 128` + +## Operator 2: ReshapeAndCache + +### Interface + +```cpp +// src/base/reshape_and_cache.h +class ReshapeAndCache : public Operator { + public: + ReshapeAndCache( + const Tensor key, // [num_tokens, num_kv_heads, head_size] + const Tensor value, // [num_tokens, num_kv_heads, head_size] + const Tensor kv_cache, // [num_blocks, block_size, num_kv_heads, head_size] + const Tensor slot_mapping, // [num_tokens], INT64 + Tensor kv_cache_out // same shape as kv_cache (in-place) + ); + + virtual void operator()( + const Tensor key, const Tensor value, + const Tensor kv_cache, const Tensor slot_mapping, + Tensor kv_cache_out + ) const = 0; +}; +``` + +### Behavior + +Scatter-writes new key/value tokens into the paged KV cache. 
For each token `i`: + +``` +slot = slot_mapping[i] +block_idx = slot // block_size +offset = slot % block_size +kv_cache_out[block_idx, offset, :, :] = key[i, :, :] +``` + +### Implementation + +Start with `aclrtMemcpy`-based element-wise copy with stride arithmetic (no custom AscendC kernel). Optimize later if profiling shows this is a bottleneck. + +## File Structure + +``` +src/base/flash_attention.h # Abstract base class +src/base/reshape_and_cache.h # Abstract base class +src/ascend/flash_attention/kernel.h # Ascend specialization +src/ascend/reshape_and_cache/kernel.h # Ascend specialization +tests/test_flash_attention.py # Operator tests +tests/test_reshape_and_cache.py # Operator tests +``` + +## Testing Strategy + +### FlashAttention Tests + +Tests follow the `Payload` / `auto_act_and_assert` pattern from `conftest.py`: + +- **Prefill (no block table):** single sequence, multi-sequence with `cu_seqlens` +- **Decode (with block table):** single token per request with paged KV cache +- **GQA:** `num_kv_heads < num_heads` +- **Causal masking:** `sparse_mode=3` +- **Dtypes:** fp16, bf16 (skipped on Ascend for unsupported dtypes) +- **Reference:** PyTorch `scaled_dot_product_attention` with causal mask + +### ReshapeAndCache Tests + +- Write single token into empty paged cache, verify correct slot placement +- Write batch of tokens with contiguous slot mapping +- Write batch with non-contiguous slot mapping (holes in cache) +- **Reference:** manual scatter via NumPy indexing + +### Device Filtering + +Tests use `device="npu"` parametrization. Use `-k "not cpu"` to select Ascend tests (avoids substring match with "input"). + +## Python Bindings + +Auto-generated by `scripts/generate_wrappers.py`. 
Usage: + +```python +import infini + +# Free function +out = infini.ops.flash_attention( + query, key, value, + block_table=block_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + num_heads=32, num_kv_heads=8, head_size=128, + scale=1.0/128**0.5, sparse_mode=3, block_size=128, + output=out +) + +# ReshapeAndCache +infini.ops.reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache) +``` + +## Decisions Log + +| Decision | Choice | Rationale | +|----------|--------|-----------| +| ACLNN API | `aclnnFusedInferAttentionScore` (FIA) | Single API for prefill + decode, matches vllm-ascend's primary path | +| Tensor layout | Accept TND, reshape to BNSD internally | Matches vLLM conventions, simpler Python adapter | +| Operator scope | FlashAttention + ReshapeAndCache | Covers full vLLM attention pipeline: cache write + attention computation | +| Quantization | Not exposed in initial version | YAGNI — can add quantization params later | +| ReshapeAndCache impl | `aclrtMemcpy` with strides | Simplest, no custom kernel. Optimize after profiling. | +| KV cache format | `(num_blocks, KV_N, block_size, D)` | Better performance per ACLNN docs | + +## Out of Scope + +- MLA (Multi-head Latent Attention) support +- Quantized attention (INT8 input/output) +- Custom AscendC kernels for hot-path optimization +- Full vLLM `AttentionBackend` implementation +- Speculative decoding support +- Sparse Flash Attention (DSA) From ad8bf0691ebff29b6dbad8af982e065fb2554b86 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Wed, 8 Apr 2026 17:28:23 +0800 Subject: [PATCH 5/8] Revert "docs: add Ascend FlashAttention design spec" This reverts commit 26c2bdc5837c98ef4a58b13a1f3ef336ddee60d9. 
--- ...026-03-30-ascend-flash-attention-design.md | 225 ------------------ 1 file changed, 225 deletions(-) delete mode 100644 docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md diff --git a/docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md b/docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md deleted file mode 100644 index c07012f..0000000 --- a/docs/superpowers/specs/2026-03-30-ascend-flash-attention-design.md +++ /dev/null @@ -1,225 +0,0 @@ -# Ascend Flash Attention & Reshape-And-Cache Design - -**Date:** 2026-03-30 -**Status:** Approved -**Scope:** Two new operators for the Ascend backend, compatible with vLLM input layout conventions. - -## Overview - -Add `FlashAttention` and `ReshapeAndCache` operators to InfiniOps targeting the Ascend NPU backend. The operators wrap CANN's `aclnnFusedInferAttentionScore` (FIA) API and accept vLLM-compatible TND (token-major) tensor layouts, enabling direct integration with vLLM's attention pipeline. 
- -## Operator 1: FlashAttention - -### Interface - -```cpp -// src/base/flash_attention.h -class FlashAttention : public Operator { - public: - FlashAttention( - const Tensor query, // [num_tokens, num_heads, head_size] TND - const Tensor key, // TND or paged cache [num_blocks, KV_N, block_size, D] - const Tensor value, - std::optional block_table, // [num_reqs, max_blocks_per_req], INT32 - std::optional cu_seqlens_q, // [num_reqs + 1], INT64 - std::optional cu_seqlens_kv,// [num_reqs + 1], INT64 - int64_t num_heads, - int64_t num_kv_heads, - int64_t head_size, - double scale, // 1/sqrt(head_size) - int64_t sparse_mode, // 3 = causal (right-down triangular) - int64_t block_size, // 0 = no paging, else 128/256/384/512 - Tensor output // [num_tokens, num_heads, head_size] - ); - - virtual void operator()( - const Tensor query, const Tensor key, const Tensor value, - std::optional block_table, - std::optional cu_seqlens_q, - std::optional cu_seqlens_kv, - int64_t num_heads, int64_t num_kv_heads, int64_t head_size, - double scale, int64_t sparse_mode, int64_t block_size, - Tensor output - ) const = 0; -}; -``` - -### Tensor Layout - -All tensors use TND (token-major) layout to match vLLM conventions: - -| Tensor | Shape | Dtype | Notes | -|--------|-------|-------|-------| -| `query` | `[num_tokens, num_heads, head_size]` | fp16/bf16 | Concatenated query tokens | -| `key` | `[num_tokens, num_kv_heads, head_size]` or `[num_blocks, KV_N, block_size, D]` | fp16/bf16 | Input K or paged cache | -| `value` | Same shape as `key` | fp16/bf16 | Input V or paged cache | -| `output` | `[num_tokens, num_heads, head_size]` | fp16/bf16 | Attention output | -| `block_table` | `[num_reqs, max_blocks_per_req]` | INT32 | Paged KV cache block mapping | -| `cu_seqlens_q` | `[num_reqs + 1]` | INT64 | Cumulative query sequence lengths | -| `cu_seqlens_kv` | `[num_reqs + 1]` | INT64 | Cumulative KV sequence lengths | - -### ACLNN FIA Mapping - -The Ascend backend 
(`src/ascend/flash_attention/kernel.h`) wraps `aclnnFusedInferAttentionScore`: - -| InfiniOps | ACLNN FIA | Notes | -|-----------|-----------|-------| -| `query [T,N,D]` | `query` as `[1,N,T,D]` BNSD | Reshape (view, no copy) | -| `key` (paged cache) | `key` as `aclTensorList*` | Single-element list pointing to cache | -| `value` (paged cache) | `value` as `aclTensorList*` | Same as key | -| `block_table` | `blockTable` | Direct pass-through | -| `cu_seqlens_q` | `actualSeqLengths` | Extract to host `aclIntArray*` | -| `cu_seqlens_kv` | `actualSeqLengthsKv` | Extract to host `aclIntArray*` | -| `num_heads` | `numHeads` | | -| `num_kv_heads` | `numKeyValueHeads` | Supports GQA | -| `scale` | `scaleValue` | | -| `sparse_mode` | `sparseMode` | 3 = causal | -| `block_size` | `blockSize` | | -| `output [T,N,D]` | `attentionOut` as `[1,N,T,D]` | Reshape back | - -**Internal defaults (not exposed):** - -- `inputLayout` = `"BNSD"` -- `pseShift` = nullptr (no position encoding shift) -- `attenMask` = nullptr (causal handled by `sparseMode=3`) -- `preTokens` / `nextTokens` = `2147483647` (INT_MAX) -- `innerPrecise` = 0 (high precision mode) -- `softmaxLseFlag` = false -- All quantization parameters = nullptr - -### Workflow - -1. Reshape TND input tensors to BNSD views (no memory copy) -2. Extract `cu_seqlens_q`/`cu_seqlens_kv` to host-side `aclIntArray*` -3. Build ACL tensor descriptors via `ascend::buildAclTensor()` -4. Create `aclTensorList*` for key/value (single-element list wrapping the cache tensor) -5. Call `aclnnFusedInferAttentionScoreGetWorkspaceSize` -6. Allocate workspace via `WorkspacePool::ensure()` -7. Call `aclnnFusedInferAttentionScore` -8. Destroy all ACL descriptors - -### Constraints - -- **Dtypes:** float16, bfloat16 only -- **head_size:** must be 16-byte aligned (multiple of 8 for fp16, 4 for bf16), max 512 -- **num_heads:** max 256 -- **block_size:** 128, 256, 384, or 512 (multiple of 128). 
0 disables paging -- **KV cache format:** `(num_blocks, KV_N, block_size, D)` preferred (better performance than `(num_blocks, block_size, H)`) -- **GQA:** `num_heads % num_kv_heads == 0`, ratio <= 64 -- **Paged attention requires:** `block_table` present, `cu_seqlens_kv` provided, `block_size >= 128` - -## Operator 2: ReshapeAndCache - -### Interface - -```cpp -// src/base/reshape_and_cache.h -class ReshapeAndCache : public Operator { - public: - ReshapeAndCache( - const Tensor key, // [num_tokens, num_kv_heads, head_size] - const Tensor value, // [num_tokens, num_kv_heads, head_size] - const Tensor kv_cache, // [num_blocks, block_size, num_kv_heads, head_size] - const Tensor slot_mapping, // [num_tokens], INT64 - Tensor kv_cache_out // same shape as kv_cache (in-place) - ); - - virtual void operator()( - const Tensor key, const Tensor value, - const Tensor kv_cache, const Tensor slot_mapping, - Tensor kv_cache_out - ) const = 0; -}; -``` - -### Behavior - -Scatter-writes new key/value tokens into the paged KV cache. For each token `i`: - -``` -slot = slot_mapping[i] -block_idx = slot // block_size -offset = slot % block_size -kv_cache_out[block_idx, offset, :, :] = key[i, :, :] -``` - -### Implementation - -Start with `aclrtMemcpy`-based element-wise copy with stride arithmetic (no custom AscendC kernel). Optimize later if profiling shows this is a bottleneck. 
- -## File Structure - -``` -src/base/flash_attention.h # Abstract base class -src/base/reshape_and_cache.h # Abstract base class -src/ascend/flash_attention/kernel.h # Ascend specialization -src/ascend/reshape_and_cache/kernel.h # Ascend specialization -tests/test_flash_attention.py # Operator tests -tests/test_reshape_and_cache.py # Operator tests -``` - -## Testing Strategy - -### FlashAttention Tests - -Tests follow the `Payload` / `auto_act_and_assert` pattern from `conftest.py`: - -- **Prefill (no block table):** single sequence, multi-sequence with `cu_seqlens` -- **Decode (with block table):** single token per request with paged KV cache -- **GQA:** `num_kv_heads < num_heads` -- **Causal masking:** `sparse_mode=3` -- **Dtypes:** fp16, bf16 (skipped on Ascend for unsupported dtypes) -- **Reference:** PyTorch `scaled_dot_product_attention` with causal mask - -### ReshapeAndCache Tests - -- Write single token into empty paged cache, verify correct slot placement -- Write batch of tokens with contiguous slot mapping -- Write batch with non-contiguous slot mapping (holes in cache) -- **Reference:** manual scatter via NumPy indexing - -### Device Filtering - -Tests use `device="npu"` parametrization. Use `-k "not cpu"` to select Ascend tests (avoids substring match with "input"). - -## Python Bindings - -Auto-generated by `scripts/generate_wrappers.py`. 
Usage: - -```python -import infini - -# Free function -out = infini.ops.flash_attention( - query, key, value, - block_table=block_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - num_heads=32, num_kv_heads=8, head_size=128, - scale=1.0/128**0.5, sparse_mode=3, block_size=128, - output=out -) - -# ReshapeAndCache -infini.ops.reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache) -``` - -## Decisions Log - -| Decision | Choice | Rationale | -|----------|--------|-----------| -| ACLNN API | `aclnnFusedInferAttentionScore` (FIA) | Single API for prefill + decode, matches vllm-ascend's primary path | -| Tensor layout | Accept TND, reshape to BNSD internally | Matches vLLM conventions, simpler Python adapter | -| Operator scope | FlashAttention + ReshapeAndCache | Covers full vLLM attention pipeline: cache write + attention computation | -| Quantization | Not exposed in initial version | YAGNI — can add quantization params later | -| ReshapeAndCache impl | `aclrtMemcpy` with strides | Simplest, no custom kernel. Optimize after profiling. | -| KV cache format | `(num_blocks, KV_N, block_size, D)` | Better performance per ACLNN docs | - -## Out of Scope - -- MLA (Multi-head Latent Attention) support -- Quantized attention (INT8 input/output) -- Custom AscendC kernels for hot-path optimization -- Full vLLM `AttentionBackend` implementation -- Speculative decoding support -- Sparse Flash Attention (DSA) From 3a56aa8a81430a98bceef8c725be2dd4dd406e60 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Thu, 9 Apr 2026 16:45:10 +0800 Subject: [PATCH 6/8] feat(ascend): optimize all operator dispatch (P0-P4) and add Cast/Cat/Linear/Mul operators Descriptor caching (`AclTensorCache` + `aclSetRawTensorAddr`), executor caching (`aclSetAclOpExecutorRepeatable`), D2H sync elimination, `add_rms_norm` decomposition, and `WorkspacePool` thread-local fast path. Host dispatch dropped from ~255 us/call to 17-57 us/call for all cacheable operators. 
New operators: Cast (`aclnnCast`), Cat (`aclnnCat` with TensorList executor caching), Linear (`aclnnAddmm`/`aclnnBaddbmm`/ `aclnnMatmul`), Mul (`aclnnMul`). Full regression: 2040 passed, 0 failed. --- scripts/generate_wrappers.py | 22 + src/ascend/add/kernel.h | 49 +- src/ascend/add_rms_norm/kernel.h | 117 +++-- src/ascend/add_rms_norm/kernel_fused.h | 124 +++++ src/ascend/add_rms_norm/registry.h | 15 + src/ascend/cast/kernel.h | 60 +++ src/ascend/cat/kernel.h | 91 ++++ src/ascend/causal_softmax/kernel.h | 91 ++-- src/ascend/common.h | 119 +++++ src/ascend/flash_attention/kernel.h | 234 ++++++---- src/ascend/gemm/kernel.h | 66 ++- src/ascend/linear/kernel.h | 122 +++++ src/ascend/matmul/kernel.h | 53 ++- src/ascend/mul/kernel.h | 63 +++ src/ascend/reshape_and_cache/kernel.h | 123 +++-- src/ascend/rms_norm/kernel.h | 61 ++- src/ascend/rotary_embedding/kernel.h | 612 ++++++++----------------- src/ascend/swiglu/kernel.h | 81 ++-- src/ascend/workspace_pool_.h | 55 ++- src/base/cast.h | 52 +++ src/base/cat.h | 34 ++ src/base/linear.h | 64 +++ src/base/mul.h | 67 +++ src/cpu/cast/cast.h | 57 +++ src/cpu/cat/cat.h | 68 +++ src/cpu/linear/linear.h | 112 +++++ src/cpu/mul/mul.h | 63 +++ src/hash.h | 9 + src/operator.h | 8 + src/pybind11_utils.h | 12 +- tests/test_add_rms_norm.py | 95 ++++ tests/test_cast.py | 65 +++ tests/test_cat.py | 72 +++ tests/test_linear.py | 95 ++++ tests/test_matmul.py | 79 ++++ tests/test_mul.py | 90 ++++ tests/test_rotary_embedding.py | 15 + 37 files changed, 2502 insertions(+), 713 deletions(-) create mode 100644 src/ascend/add_rms_norm/kernel_fused.h create mode 100644 src/ascend/add_rms_norm/registry.h create mode 100644 src/ascend/cast/kernel.h create mode 100644 src/ascend/cat/kernel.h create mode 100644 src/ascend/linear/kernel.h create mode 100644 src/ascend/mul/kernel.h create mode 100644 src/base/cast.h create mode 100644 src/base/cat.h create mode 100644 src/base/linear.h create mode 100644 src/base/mul.h create mode 100644 
src/cpu/cast/cast.h create mode 100644 src/cpu/cat/cat.h create mode 100644 src/cpu/linear/linear.h create mode 100644 src/cpu/mul/mul.h create mode 100644 tests/test_add_rms_norm.py create mode 100644 tests/test_cast.py create mode 100644 tests/test_cat.py create mode 100644 tests/test_linear.py create mode 100644 tests/test_matmul.py create mode 100644 tests/test_mul.py diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index fc8f1bf..18c61cb 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -103,14 +103,30 @@ def _find_optional_tensor_params(op_name): return set(re.findall(r"std::optional\s+(\w+)", source)) +def _find_vector_tensor_params(op_name): + """Return a set of parameter names declared as `std::vector` in + the base header. + """ + import re + + source = (_BASE_DIR / f"{op_name}.h").read_text() + return set(re.findall(r"std::vector\s+(\w+)", source)) + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) def _is_optional_tensor(arg): if arg.spelling in optional_tensor_params: return True return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + def _generate_params(node): parts = [] for arg in node.get_arguments(): @@ -118,6 +134,8 @@ def _generate_params(node): continue if _is_optional_tensor(arg): parts.append(f"std::optional {arg.spelling}") + elif _is_vector_tensor(arg): + parts.append(f"std::vector {arg.spelling}") else: param = ( arg.type.spelling @@ -136,6 +154,10 @@ def _generate_arguments(node): args.append( f"OptionalTensorFromPybind11Handle({arg.spelling})" ) + elif _is_vector_tensor(arg): + args.append( + f"VectorTensorFromPybind11Handle({arg.spelling})" + ) elif "Tensor" in 
arg.type.spelling: args.append(f"TensorFromPybind11Handle({arg.spelling})") else: diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h index e81f9bd..650edeb 100644 --- a/src/ascend/add/kernel.h +++ b/src/ascend/add/kernel.h @@ -16,7 +16,10 @@ template <> class Operator : public Add { public: Operator(const Tensor input, const Tensor other, Tensor out) - : Add(input, other, out) { + : Add(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) { // aclCreateScalar stores the pointer rather than copying the value, so // alpha_storage_* must remain alive for the lifetime of alpha_. // The alpha scalar type must match the tensor dtype: use int64 for integer @@ -28,25 +31,45 @@ class Operator : public Add { } } - ~Operator() { aclDestroyScalar(alpha_); } + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + aclDestroyScalar(alpha_); + } void operator()(const Tensor input, const Tensor other, Tensor out) const override { auto stream = static_cast(stream_); - auto t_in = ascend::buildAclTensor(input); - auto t_oth = ascend::buildAclTensor(other); - auto t_out = ascend::buildAclTensor(out); - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_needed, &executor); - auto& arena = ascend::workspacePool().ensure(stream, ws_needed); - aclnnAdd(arena.buf, ws_needed, executor, stream); - aclDestroyTensor(t_in); - aclDestroyTensor(t_oth); - aclDestroyTensor(t_out); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, 
out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnAdd(arena.buf, ws_size_, executor_, stream); } private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; + float alpha_float_storage_ = 1.0f; // stable address for aclCreateScalar (float) int64_t alpha_int_storage_ = 1; // stable address for aclCreateScalar (int) diff --git a/src/ascend/add_rms_norm/kernel.h b/src/ascend/add_rms_norm/kernel.h index 28ae702..4f9670a 100644 --- a/src/ascend/add_rms_norm/kernel.h +++ b/src/ascend/add_rms_norm/kernel.h @@ -5,58 +5,121 @@ #include "acl/acl.h" #include "aclnn/aclnn_base.h" -#include "aclnn_add_rms_norm.h" +#include "aclnn_add.h" +#include "aclnn_rms_norm.h" #include "ascend/common.h" +#include "ascend/add_rms_norm/registry.h" #include "ascend/workspace_pool_.h" -#include "base/add_rms_norm.h" #include "operator.h" namespace infini::ops { +// Decomposed implementation: aclnnAdd + aclnnRmsNorm. +// +// The fused aclnnAddRmsNorm API has ~200 us host-side launch overhead that +// dominates small-tensor dispatch. Decomposing into two fast ACLNN calls +// reduces host dispatch from ~224 us to ~56 us (4x faster) with negligible +// NPU-side impact for inference tensor sizes. template <> -class Operator : public AddRmsNorm { +class Operator : public AddRmsNorm { public: Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, Tensor y_out, Tensor x_out) - : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out) { - // aclnnAddRmsNorm writes rstd as a required side output. - // Allocate a persistent device buffer for it. + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out), + x1_cache_(x1), + x2_cache_(x2), + gamma_cache_(gamma), + y_out_cache_(y_out), + x_out_cache_(x_out) { + // Alpha scalar for aclnnAdd (x_out = x1 + 1.0 * x2). 
+ alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + + // aclnnRmsNorm writes rstd as a required side output. + rstd_shape_ = {static_cast(batch_size_), + static_cast(nhead_)}; size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float); aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_data_); } ~Operator() { + if (add_exec_) aclDestroyAclOpExecutor(add_exec_); + if (norm_exec_) aclDestroyAclOpExecutor(norm_exec_); + aclDestroyScalar(alpha_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); if (rstd_data_) aclrtFree(rstd_data_); } void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, Tensor y_out, Tensor x_out) const override { - auto t_x1 = ascend::buildAclTensor(x1); - auto t_x2 = ascend::buildAclTensor(x2); - auto t_gamma = ascend::buildAclTensor(gamma); - auto t_y_out = ascend::buildAclTensor(y_out); - auto t_x_out = ascend::buildAclTensor(x_out); - // rstd is always float32 regardless of input dtype. 
- auto t_rstd = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, - /*strides=*/nullptr, 0, ACL_FORMAT_ND, - rstd_shape_.data(), 2, rstd_data_); - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - aclnnAddRmsNormGetWorkspaceSize(t_x1, t_x2, t_gamma, eps, t_y_out, t_rstd, - t_x_out, &ws_needed, &executor); + auto t_x1 = x1_cache_.get(const_cast(x1.data())); + auto t_x2 = x2_cache_.get(const_cast(x2.data())); + auto t_gamma = gamma_cache_.get(const_cast(gamma.data())); + auto t_y_out = y_out_cache_.get(y_out.data()); + auto t_x_out = x_out_cache_.get(x_out.data()); auto stream = static_cast(stream_); - auto& arena = ascend::workspacePool().ensure(stream, ws_needed); - aclnnAddRmsNorm(arena.buf, ws_needed, executor, stream); - aclDestroyTensor(t_x1); - aclDestroyTensor(t_x2); - aclDestroyTensor(t_gamma); - aclDestroyTensor(t_y_out); - aclDestroyTensor(t_rstd); - aclDestroyTensor(t_x_out); + + // Step 1: x_out = x1 + x2. + if (!add_exec_) { + aclnnAddGetWorkspaceSize(t_x1, t_x2, alpha_, t_x_out, &add_ws_, + &add_exec_); + aclSetAclOpExecutorRepeatable(add_exec_); + } else { + aclSetInputTensorAddr(add_exec_, 0, t_x1, + const_cast(x1.data())); + aclSetInputTensorAddr(add_exec_, 1, t_x2, + const_cast(x2.data())); + aclSetOutputTensorAddr(add_exec_, 0, t_x_out, x_out.data()); + } + auto& add_arena = ascend::workspacePool().ensure(stream, add_ws_); + aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream); + + // Step 2: y_out = rms_norm(x_out, gamma, eps). 
+ if (!norm_exec_) { + aclnnRmsNormGetWorkspaceSize(t_x_out, t_gamma, eps, t_y_out, + rstd_tensor_, &norm_ws_, &norm_exec_); + aclSetAclOpExecutorRepeatable(norm_exec_); + } else { + aclSetInputTensorAddr(norm_exec_, 0, t_x_out, x_out.data()); + aclSetInputTensorAddr(norm_exec_, 1, t_gamma, + const_cast(gamma.data())); + aclSetOutputTensorAddr(norm_exec_, 0, t_y_out, y_out.data()); + } + auto& norm_arena = ascend::workspacePool().ensure(stream, norm_ws_); + aclnnRmsNorm(norm_arena.buf, norm_ws_, norm_exec_, stream); } private: + mutable ascend::AclTensorCache x1_cache_; + + mutable ascend::AclTensorCache x2_cache_; + + mutable ascend::AclTensorCache gamma_cache_; + + mutable ascend::AclTensorCache y_out_cache_; + + mutable ascend::AclTensorCache x_out_cache_; + + float alpha_storage_ = 1.0f; + + aclScalar* alpha_ = nullptr; + + std::vector rstd_shape_; + void* rstd_data_ = nullptr; + + aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* add_exec_ = nullptr; + + mutable uint64_t add_ws_ = 0; + + mutable aclOpExecutor* norm_exec_ = nullptr; + + mutable uint64_t norm_ws_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/add_rms_norm/kernel_fused.h b/src/ascend/add_rms_norm/kernel_fused.h new file mode 100644 index 0000000..2959a73 --- /dev/null +++ b/src/ascend/add_rms_norm/kernel_fused.h @@ -0,0 +1,124 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_add_rms_norm.h" +#include "ascend/add_rms_norm/registry.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "operator.h" + +namespace infini::ops { + +// Fused implementation via aclnnAddRmsNorm (implementation index 1). +// +// Computes x_out = x1 + x2 and y_out = rms_norm(x_out, gamma, eps) in a +// single CANN launch. 
The fused API has higher host-side launch overhead +// (~200 us) compared to the decomposed aclnnAdd + aclnnRmsNorm path (~39 us), +// but may offer better NPU-side efficiency for large tensors where kernel +// fusion reduces memory traffic. +// +// Select via `implementation_index=1` in Python: +// infini.ops.add_rms_norm(..., implementation_index=1, stream=s) +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out), + x1_cache_(x1), + x2_cache_(x2), + gamma_cache_(gamma), + y_out_cache_(y_out), + x_out_cache_(x_out) { + // aclnnAddRmsNorm requires rstdOut to have the same ndim as x1, with + // the last gamma.ndim() dimensions set to 1. For example: + // x1 shape(2, 32, 128), gamma shape(128) -> rstdOut shape(2, 32, 1) + // x1 shape(64, 128), gamma shape(128) -> rstdOut shape(64, 1) + fused_rstd_shape_.reserve(ndim_); + for (size_t i = 0; i < ndim_ - gamma.ndim(); ++i) { + fused_rstd_shape_.push_back(static_cast(x1.size(i))); + } + for (size_t i = 0; i < gamma.ndim(); ++i) { + fused_rstd_shape_.push_back(1); + } + + size_t rstd_elems = 1; + for (auto d : fused_rstd_shape_) { + rstd_elems *= static_cast(d); + } + size_t rstd_bytes = rstd_elems * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + rstd_tensor_ = aclCreateTensor( + fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), rstd_data_); + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto t_x1 = x1_cache_.get(const_cast(x1.data())); 
+ auto t_x2 = x2_cache_.get(const_cast(x2.data())); + auto t_gamma = gamma_cache_.get(const_cast(gamma.data())); + auto t_y_out = y_out_cache_.get(y_out.data()); + auto t_x_out = x_out_cache_.get(x_out.data()); + auto stream = static_cast(stream_); + + if (!executor_) { + aclnnAddRmsNormGetWorkspaceSize(t_x1, t_x2, t_gamma, + static_cast(eps), t_y_out, + rstd_tensor_, t_x_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_x1, + const_cast(x1.data())); + aclSetInputTensorAddr(executor_, 1, t_x2, + const_cast(x2.data())); + aclSetInputTensorAddr(executor_, 2, t_gamma, + const_cast(gamma.data())); + aclSetOutputTensorAddr(executor_, 0, t_y_out, y_out.data()); + // rstd at output index 1 has a stable address — no update needed. + aclSetOutputTensorAddr(executor_, 2, t_x_out, x_out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnAddRmsNorm(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache x1_cache_; + + mutable ascend::AclTensorCache x2_cache_; + + mutable ascend::AclTensorCache gamma_cache_; + + mutable ascend::AclTensorCache y_out_cache_; + + mutable ascend::AclTensorCache x_out_cache_; + + std::vector fused_rstd_shape_; + + void* rstd_data_ = nullptr; + + aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/registry.h b/src/ascend/add_rms_norm/registry.h new file mode 100644 index 0000000..d48de30 --- /dev/null +++ b/src/ascend/add_rms_norm/registry.h @@ -0,0 +1,15 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_REGISTRY_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_REGISTRY_H_ + +#include "base/add_rms_norm.h" + +namespace infini::ops { + +template <> +struct ActiveImplementationsImpl { + using type = List<0, 1>; +}; + +} // namespace infini::ops + +#endif diff --git 
a/src/ascend/cast/kernel.h b/src/ascend/cast/kernel.h new file mode 100644 index 0000000..645f05a --- /dev/null +++ b/src/ascend/cast/kernel.h @@ -0,0 +1,60 @@ +#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAST_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cast.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) + : Cast(input, out), + in_cache_(input), + out_cache_(out), + acl_out_dtype_(ascend::toAclDtype(out.dtype())) {} + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + } + + void operator()(const Tensor input, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnCast(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + aclDataType acl_out_dtype_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/cat/kernel.h b/src/ascend/cat/kernel.h new file mode 100644 index 0000000..a847b92 --- /dev/null +++ b/src/ascend/cat/kernel.h @@ -0,0 +1,91 @@ +#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAT_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn/acl_meta.h" 
+#include "aclnnop/aclnn_cat.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cat.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat(first_input, rest_inputs, dim, out), out_cache_(out) { + // Build AclTensorCache for each input tensor. + in_caches_.reserve(input_count_); + in_caches_.emplace_back(first_input); + for (const auto& t : rest_inputs) { + in_caches_.emplace_back(t); + } + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (tensor_list_) aclDestroyTensorList(tensor_list_); + } + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) const override { + auto stream = static_cast(stream_); + + // Collect all input tensors in order. + std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + // First call: create descriptors, tensor list, and executor. + std::vector acl_tensors(input_count_); + for (size_t i = 0; i < input_count_; ++i) { + acl_tensors[i] = + in_caches_[i].get(const_cast(inputs[i]->data())); + } + + tensor_list_ = aclCreateTensorList( + const_cast(acl_tensors.data()), + static_cast(input_count_)); + + aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + // Subsequent calls: update data pointers on cached descriptors. 
+ for (size_t i = 0; i < input_count_; ++i) { + in_caches_[i].get(const_cast(inputs[i]->data())); + } + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnCat(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable std::vector in_caches_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclTensorList* tensor_list_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/causal_softmax/kernel.h b/src/ascend/causal_softmax/kernel.h index 5883c42..a27cb5d 100644 --- a/src/ascend/causal_softmax/kernel.h +++ b/src/ascend/causal_softmax/kernel.h @@ -28,7 +28,10 @@ namespace infini::ops { template <> class Operator : public CausalSoftmax { public: - Operator(const Tensor input, Tensor out) : CausalSoftmax(input, out) { + Operator(const Tensor input, Tensor out) + : CausalSoftmax(input, out), + in_cache_(input), + out_cache_(out) { // Contiguous temp buffer with the same element count as input. size_t n_elems = input.numel(); size_t elem_bytes = kDataTypeToSize.at(dtype_); @@ -36,6 +39,7 @@ class Operator : public CausalSoftmax { // Build a contiguous Tensor descriptor pointing to temp_buf_. Tensor temp_t{temp_buf_, input.shape(), input.dtype(), input.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); // Causal mask: mask[i][j] = 1 when position j must be masked for query i. // Shape (seq_len, total_seq_len) – broadcasts over the batch dimension. 
@@ -69,6 +73,9 @@ class Operator : public CausalSoftmax { } ~Operator() { + if (copy_exec_) aclDestroyAclOpExecutor(copy_exec_); + if (fill_exec_) aclDestroyAclOpExecutor(fill_exec_); + if (softmax_exec_) aclDestroyAclOpExecutor(softmax_exec_); aclrtFree(temp_buf_); aclrtFree(mask_buf_); aclDestroyTensor(mask_tensor_); @@ -76,50 +83,74 @@ class Operator : public CausalSoftmax { } void operator()(const Tensor input, Tensor out) const override { - Tensor temp_t{temp_buf_, input.shape(), input.dtype(), input.device()}; - auto t_in = ascend::buildAclTensor(input); - auto t_temp = ascend::buildAclTensor(temp_t); - auto t_out = ascend::buildAclTensor(out); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_temp = temp_cache_.get(temp_buf_); + auto t_out = out_cache_.get(out.data()); auto stream = static_cast(stream_); - uint64_t ws_needed = 0; - aclOpExecutor* exec = nullptr; - // Step 1: copy input (possibly non-contiguous) into contiguous temp. - aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, &ws_needed, &exec); - auto& copy_arena = ascend::workspacePool().ensure(stream, ws_needed); - uint64_t copy_ws = ws_needed; - aclnnInplaceCopy(copy_arena.buf, copy_ws, exec, stream); + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, ©_ws_, ©_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_temp, temp_buf_); + aclSetInputTensorAddr(copy_exec_, 1, t_in, + const_cast(input.data())); + } + auto& copy_arena = ascend::workspacePool().ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); // Step 2: mask upper-triangle positions with -inf in-place. 
- ws_needed = 0; - exec = nullptr; - aclnnInplaceMaskedFillScalarGetWorkspaceSize(t_temp, mask_tensor_, neg_inf_, - &ws_needed, &exec); - auto& fill_arena = ascend::workspacePool().ensure(stream, ws_needed); - uint64_t fill_ws = ws_needed; - aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws, exec, stream); + // mask_tensor_ and neg_inf_ have stable addresses — first-call only. + if (!fill_exec_) { + aclnnInplaceMaskedFillScalarGetWorkspaceSize( + t_temp, mask_tensor_, neg_inf_, &fill_ws_, &fill_exec_); + aclSetAclOpExecutorRepeatable(fill_exec_); + } + auto& fill_arena = ascend::workspacePool().ensure(stream, fill_ws_); + aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws_, fill_exec_, stream); // Step 3: softmax over the last dimension → out. - ws_needed = 0; - exec = nullptr; - constexpr int64_t kLastDim = -1; - aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &ws_needed, &exec); - auto& softmax_arena = ascend::workspacePool().ensure(stream, ws_needed); - uint64_t softmax_ws = ws_needed; - aclnnSoftmax(softmax_arena.buf, softmax_ws, exec, stream); - - aclDestroyTensor(t_in); - aclDestroyTensor(t_temp); - aclDestroyTensor(t_out); + if (!softmax_exec_) { + constexpr int64_t kLastDim = -1; + aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &softmax_ws_, + &softmax_exec_); + aclSetAclOpExecutorRepeatable(softmax_exec_); + } else { + aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data()); + } + auto& softmax_arena = ascend::workspacePool().ensure(stream, softmax_ws_); + aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream); } private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + float neg_inf_storage_ = -std::numeric_limits::infinity(); + void* temp_buf_ = nullptr; + void* mask_buf_ = nullptr; + aclTensor* mask_tensor_ = nullptr; + aclScalar* neg_inf_ = nullptr; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ 
= 0; + + mutable aclOpExecutor* fill_exec_ = nullptr; + + mutable uint64_t fill_ws_ = 0; + + mutable aclOpExecutor* softmax_exec_ = nullptr; + + mutable uint64_t softmax_ws_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/common.h b/src/ascend/common.h index f5ecb1a..639a635 100644 --- a/src/ascend/common.h +++ b/src/ascend/common.h @@ -51,6 +51,125 @@ inline aclTensor* buildAclTensor(const Tensor& t, static_cast(storage_shape.size()), const_cast(t.data())); } +// Pre-computed tensor metadata for descriptor reuse. +// +// Stores shape, strides, storage_shape, and dtype once (avoiding per-call heap +// allocations). The aclTensor descriptor is created on the first `get()` call +// and its data pointer is updated in-place via `aclSetRawTensorAddr` on +// subsequent calls. +class AclTensorCache { + public: + AclTensorCache() = default; + + // Construct from explicit metadata (for device buffers not wrapped in Tensor). + // Computes contiguous strides from shape. + AclTensorCache(std::vector shape, aclDataType dtype, void* data) + : shape_(std::move(shape)), dtype_(dtype) { + strides_.resize(shape_.size()); + int64_t stride = 1; + for (int i = static_cast(shape_.size()) - 1; i >= 0; --i) { + strides_[i] = stride; + stride *= shape_[i]; + } + storage_shape_ = {stride}; + + if (data) { + tensor_ = aclCreateTensor( + shape_.data(), static_cast(shape_.size()), dtype_, + strides_.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(), + static_cast(storage_shape_.size()), data); + } + } + + explicit AclTensorCache(const Tensor& t, bool transpose_last2 = false) + : dtype_{toAclDtype(t.dtype())} { + shape_.assign(t.shape().begin(), t.shape().end()); + strides_.assign(t.strides().begin(), t.strides().end()); + + if (transpose_last2 && shape_.size() >= 2) { + auto n = shape_.size(); + std::swap(shape_[n - 2], shape_[n - 1]); + std::swap(strides_[n - 2], strides_[n - 1]); + } + + int64_t storage_elems = 1; + for (size_t i = 0; i < shape_.size(); ++i) { + 
if (shape_[i] == 0) { + storage_elems = 0; + break; + } + if (strides_[i] > 0 && shape_[i] > 1) { + storage_elems += static_cast(shape_[i] - 1) * strides_[i]; + } + } + storage_shape_ = {storage_elems}; + } + + ~AclTensorCache() { + if (tensor_) { + aclDestroyTensor(tensor_); + } + } + + AclTensorCache(const AclTensorCache&) = delete; + + AclTensorCache& operator=(const AclTensorCache&) = delete; + + AclTensorCache(AclTensorCache&& o) noexcept + : shape_(std::move(o.shape_)), + strides_(std::move(o.strides_)), + storage_shape_(std::move(o.storage_shape_)), + dtype_(o.dtype_), + tensor_(o.tensor_) { + o.tensor_ = nullptr; + } + + AclTensorCache& operator=(AclTensorCache&& o) noexcept { + if (this != &o) { + if (tensor_) { + aclDestroyTensor(tensor_); + } + shape_ = std::move(o.shape_); + strides_ = std::move(o.strides_); + storage_shape_ = std::move(o.storage_shape_); + dtype_ = o.dtype_; + tensor_ = o.tensor_; + o.tensor_ = nullptr; + } + + return *this; + } + + // Update the data pointer and return the cached descriptor. + aclTensor* get(void* data) const { + if (tensor_) { + aclSetRawTensorAddr(tensor_, data); + + return tensor_; + } + + tensor_ = aclCreateTensor( + shape_.data(), static_cast(shape_.size()), dtype_, + strides_.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(), + static_cast(storage_shape_.size()), data); + + return tensor_; + } + + private: + std::vector shape_; + + std::vector strides_; + + std::vector storage_shape_; + + aclDataType dtype_{ACL_DT_UNDEFINED}; + + mutable aclTensor* tensor_ = nullptr; +}; + } // namespace infini::ops::ascend #endif diff --git a/src/ascend/flash_attention/kernel.h b/src/ascend/flash_attention/kernel.h index 3b82e53..3dae947 100644 --- a/src/ascend/flash_attention/kernel.h +++ b/src/ascend/flash_attention/kernel.h @@ -43,18 +43,32 @@ inline aclTensor* reshapeView(const Tensor& t, // Extract cu_seqlens differences to a host aclIntArray. // cu_seqlens = [0, s1, s1+s2, ...] 
-> per_seq_lens = [s1, s2, ...]. // Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths). +// +// When cu_seqlens is a CPU tensor (device type kCpu), the data pointer is +// already on the host and can be read directly — no D2H sync needed. inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens, aclrtStream stream) { auto n = cu_seqlens.numel(); - std::vector cu_host(n); - aclrtMemcpyAsync(cu_host.data(), n * sizeof(int64_t), cu_seqlens.data(), - n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); - aclrtSynchronizeStream(stream); + + const int64_t* cu_host_ptr = nullptr; + std::vector cu_host_buf; + + if (cu_seqlens.device().type() == Device::Type::kCpu) { + cu_host_ptr = static_cast(cu_seqlens.data()); + } else { + cu_host_buf.resize(n); + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), + cu_seqlens.data(), n * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + cu_host_ptr = cu_host_buf.data(); + } std::vector lengths(n - 1); for (size_t i = 0; i < lengths.size(); ++i) { - lengths[i] = cu_host[i + 1] - cu_host[i]; + lengths[i] = cu_host_ptr[i + 1] - cu_host_ptr[i]; } + return aclCreateIntArray(lengths.data(), static_cast(lengths.size())); } @@ -63,16 +77,28 @@ inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens, // cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...]. // FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend // convention for npu_fused_infer_attention_score actual_seq_lengths. +// +// When cu_seqlens is a CPU tensor, reads directly from host memory. 
inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens, aclrtStream stream) { auto n = cu_seqlens.numel(); - std::vector cu_host(n); - aclrtMemcpyAsync(cu_host.data(), n * sizeof(int64_t), cu_seqlens.data(), - n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); - aclrtSynchronizeStream(stream); + + const int64_t* cu_host_ptr = nullptr; + std::vector cu_host_buf; + + if (cu_seqlens.device().type() == Device::Type::kCpu) { + cu_host_ptr = static_cast(cu_seqlens.data()); + } else { + cu_host_buf.resize(n); + aclrtMemcpyAsync(cu_host_buf.data(), n * sizeof(int64_t), + cu_seqlens.data(), n * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + cu_host_ptr = cu_host_buf.data(); + } // Skip the leading 0; return [s1, s1+s2, ...]. - return aclCreateIntArray(cu_host.data() + 1, static_cast(n - 1)); + return aclCreateIntArray(cu_host_ptr + 1, static_cast(n - 1)); } // Allocate a 2048x2048 lower-triangular UINT8 causal mask on device. @@ -107,7 +133,58 @@ inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) { template <> class Operator : public FlashAttention { public: - using FlashAttention::FlashAttention; + Operator(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, bool causal, + int64_t window_left, int64_t window_right, int64_t block_size, + Tensor output) + : FlashAttention(query, key, value, cu_seqlens_q, cu_seqlens_kv, + block_table, num_heads, num_kv_heads, head_size, scale, + causal, window_left, window_right, block_size, output) { + paged_ = block_table.has_value() && block_size > 0; + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + + if (!paged_) { + // Prefill: cache Q and output (TND layout). 
+ prefill_q_cache_ = ascend::AclTensorCache(query); + prefill_out_cache_ = ascend::AclTensorCache(output); + + // Pre-compute causal mask once (sparse_mode >= 2). + if (causal) { + int64_t sm = (window_left >= 0) ? 4 : 3; + if (sm >= 2) { + causal_mask_ = detail::makeCausalMask(&causal_mask_buf_, nullptr); + } + } + } else { + // Decode: cache Q/output (BNSD), block_table. + const int64_t N = query.size(1); + const int64_t D = query.size(2); + const int64_t B = query.size(0); + + decode_q_cache_ = ascend::AclTensorCache( + {B, N, 1, D}, acl_dt, const_cast(query.data())); + decode_out_cache_ = ascend::AclTensorCache( + {B, N, 1, D}, acl_dt, output.data()); + block_table_cache_ = ascend::AclTensorCache(block_table.value()); + + // Pre-compute KV reshape metadata. + const int64_t nb = key.size(0); + const int64_t bsz = key.size(1); + const int64_t NkvD = key.size(2) * key.size(3); + kv_shape_ = {nb, bsz, NkvD}; + kv_strides_ = {bsz * NkvD, NkvD, 1}; + kv_storage_shape_ = {nb * bsz * NkvD}; + kv_acl_dt_ = acl_dt; + } + } + + ~Operator() { + if (causal_mask_) aclDestroyTensor(causal_mask_); + if (causal_mask_buf_) aclrtFree(causal_mask_buf_); + } void operator()(const Tensor query, const Tensor key, const Tensor value, std::optional cu_seqlens_q, @@ -117,28 +194,18 @@ class Operator : public FlashAttention { bool causal, int64_t window_left, int64_t window_right, int64_t block_size, Tensor output) const override { auto stream = static_cast(stream_); - const bool paged = block_table.has_value() && block_size > 0; - - // Map causal + window_left/right to FIA sparse_mode / preTokens / - // nextTokens. - // - // causal=true, window_left<0 -> sparse_mode=3 (full causal) - // causal=true, window_left>=0 -> sparse_mode=4 (sliding - // window causal) causal=false -> sparse_mode=0 - // (no mask) - // - // sparse_mode is ignored by FIA when Q_S=1 (paged decode); effective_sparse - // is set to 0 in that path to avoid allocating the unnecessary causal mask. 
+ const bool paged = paged_; + int64_t sparse_mode; int64_t pre_tokens = 2147483647; int64_t next_tokens = 2147483647; if (causal) { if (window_left >= 0) { - sparse_mode = 4; // band: sliding window causal + sparse_mode = 4; pre_tokens = window_left; next_tokens = 0; } else { - sparse_mode = 3; // rightDownCausal: full causal, pre/next ignored + sparse_mode = 3; next_tokens = 0; } } else { @@ -148,14 +215,11 @@ class Operator : public FlashAttention { } if (!paged) { - // --- Prefill (single- or multi-sequence) --- - // V4 TND: query/key/value passed as token-packed [T, N, D]; per-sequence - // lengths are derived from cu_seqlens. Single fused call for all - // sequences, equivalent to flash_attn_varlen_func on CUDA. + // --- Prefill --- int64_t T = query.size(0); - // V4 TND varlen uses cumulative end positions [s1, s1+s2, ...]. - // For single-seq (no cu_seqlens), [T] is both per-seq and cumulative. + // cumSeqLengths / extractSeqLengths automatically skip D2H when + // cu_seqlens is a CPU tensor (see detail:: helpers above). aclIntArray* seq_q = cu_seqlens_q.has_value() ? detail::cumSeqLengths(cu_seqlens_q.value(), stream) @@ -165,44 +229,24 @@ class Operator : public FlashAttention { ? detail::cumSeqLengths(cu_seqlens_kv.value(), stream) : aclCreateIntArray(&T, 1); - aclTensor* t_q = ascend::buildAclTensor(query); + aclTensor* t_q = prefill_q_cache_.get(const_cast(query.data())); + // K/V descriptors go into TensorList which takes ownership — must be + // per-call (cannot cache). aclTensor* t_k = ascend::buildAclTensor(key); aclTensor* t_v = ascend::buildAclTensor(value); - aclTensor* t_out = ascend::buildAclTensor(output); + aclTensor* t_out = prefill_out_cache_.get(output.data()); const aclTensor* k_arr[] = {t_k}; const aclTensor* v_arr[] = {t_v}; aclTensorList* key_list = aclCreateTensorList(k_arr, 1); aclTensorList* val_list = aclCreateTensorList(v_arr, 1); - // sparseMode 2/3/4 require a 2048x2048 lower-triangular causal mask. 
- aclTensor* atten_mask = nullptr; - void* mask_buf = nullptr; - if (sparse_mode >= 2) { - atten_mask = detail::makeCausalMask(&mask_buf, stream); - } - uint64_t ws_needed = 0; aclOpExecutor* executor = nullptr; - // Parameter order: query, key, value, - // pseShift, attenMask, actualSeqLengths, actualSeqLengthsKv, - // deqScale1, quantScale1, deqScale2, quantScale2, quantOffset2, - // antiquantScale, antiquantOffset, - // blockTable, queryPaddingSize, kvPaddingSize, - // keyAntiquantScale, keyAntiquantOffset, - // valueAntiquantScale, valueAntiquantOffset, - // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen, - // queryRope, keyRope, keyRopeAntiquantScale, - // dequantScaleQuery, learnableSink, - // numHeads, scaleValue, preTokens, nextTokens, inputLayout, - // numKeyValueHeads, sparseMode, innerPrecise, blockSize, - // antiquantMode, softmaxLseFlag, - // keyAntiquantMode, valueAntiquantMode, queryQuantMode, - // attentionOut, softmaxLse, workspaceSize, executor aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( t_q, key_list, val_list, - nullptr, // pseShift - atten_mask, // attenMask + nullptr, // pseShift + causal_mask_, // attenMask (pre-computed, or nullptr) seq_q, // actualSeqLengths seq_kv, // actualSeqLengthsKv nullptr, nullptr, nullptr, nullptr, @@ -234,44 +278,40 @@ class Operator : public FlashAttention { assert(ret == ACL_SUCCESS && "aclnnFusedInferAttentionScoreV4 failed (prefill)"); - aclDestroyTensor(t_q); - aclDestroyTensor(t_out); + // t_q and t_out are owned by caches — do NOT destroy. + // t_k and t_v are owned by TensorLists. aclDestroyTensorList(key_list); aclDestroyTensorList(val_list); aclDestroyIntArray(seq_q); aclDestroyIntArray(seq_kv); - if (atten_mask) aclDestroyTensor(atten_mask); - if (mask_buf) aclrtFree(mask_buf); return; } // --- Paged decode --- - // V4 BNSD: reshape query/output [B, N, D] -> [B, N, 1, D]. 
- // KV cache [num_blocks, block_size, N_kv, D] flattened to - // [num_blocks, block_size, N_kv*D] (zero-copy, FIA BSH kv format). assert(cu_seqlens_kv.has_value() && "`FlashAttention` paged decode requires `cu_seqlens_kv`"); - const int64_t N = query.size(1); - const int64_t D = query.size(2); - const int64_t B = query.size(0); - const int64_t nb = key.size(0); - const int64_t bsz = key.size(1); - const int64_t NkvD = key.size(2) * key.size(3); - - std::vector bnsd_sh = {B, N, 1, D}; - std::vector bnsd_st = {N * D, D, D, 1}; - aclTensor* t_query = detail::reshapeView(query, bnsd_sh, bnsd_st); - aclTensor* t_output = detail::reshapeView(output, bnsd_sh, bnsd_st); - - std::vector kv_sh = {nb, bsz, NkvD}; - std::vector kv_st = {bsz * NkvD, NkvD, 1}; - aclTensor* t_key = detail::reshapeView(key, kv_sh, kv_st); - aclTensor* t_value = detail::reshapeView(value, kv_sh, kv_st); - + aclTensor* t_query = decode_q_cache_.get(const_cast(query.data())); + aclTensor* t_output = decode_out_cache_.get(output.data()); + + // K/V descriptors go into TensorList which takes ownership — must be + // per-call. Use pre-computed metadata to avoid heap allocs. + aclTensor* t_key = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(key.data())); + aclTensor* t_value = aclCreateTensor( + kv_shape_.data(), static_cast(kv_shape_.size()), kv_acl_dt_, + kv_strides_.data(), 0, ACL_FORMAT_ND, kv_storage_shape_.data(), + static_cast(kv_storage_shape_.size()), + const_cast(value.data())); + + // extractSeqLengths skips D2H when cu_seqlens_kv is a CPU tensor. 
aclIntArray* seq_kv = detail::extractSeqLengths(cu_seqlens_kv.value(), stream); - aclTensor* t_block_table = ascend::buildAclTensor(block_table.value()); + aclTensor* t_block_table = + block_table_cache_.get(const_cast(block_table.value().data())); const aclTensor* k_arr[] = {t_key}; const aclTensor* v_arr[] = {t_value}; @@ -307,13 +347,37 @@ class Operator : public FlashAttention { assert(ret == ACL_SUCCESS && "aclnnFusedInferAttentionScoreV4 failed (decode)"); - aclDestroyTensor(t_query); - aclDestroyTensor(t_output); + // t_query, t_output, t_block_table owned by caches — do NOT destroy. + // t_key, t_value owned by TensorLists. aclDestroyTensorList(key_list); aclDestroyTensorList(val_list); - aclDestroyTensor(t_block_table); aclDestroyIntArray(seq_kv); } + + private: + bool paged_ = false; + + mutable ascend::AclTensorCache prefill_q_cache_; + + mutable ascend::AclTensorCache prefill_out_cache_; + + mutable ascend::AclTensorCache decode_q_cache_; + + mutable ascend::AclTensorCache decode_out_cache_; + + mutable ascend::AclTensorCache block_table_cache_; + + aclTensor* causal_mask_ = nullptr; + + void* causal_mask_buf_ = nullptr; + + std::vector kv_shape_; + + std::vector kv_strides_; + + std::vector kv_storage_shape_; + + aclDataType kv_acl_dt_ = ACL_DT_UNDEFINED; }; } // namespace infini::ops diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h index ceed55a..a59d624 100644 --- a/src/ascend/gemm/kernel.h +++ b/src/ascend/gemm/kernel.h @@ -21,12 +21,17 @@ class Operator : public Gemm { : Gemm(a, b, alpha, beta, trans_a, trans_b, c), batched_{batch_count_ > 1}, alpha_val_{alpha.value_or(1.0f)}, - beta_val_{beta.value_or(1.0f)} { + beta_val_{beta.value_or(1.0f)}, + self_cache_(c), + a_cache_(a, trans_a_), + b_cache_(b, trans_b_), + out_cache_(c) { alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT); beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT); } ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); 
aclDestroyScalar(alpha_scalar_); aclDestroyScalar(beta_scalar_); } @@ -36,43 +41,60 @@ class Operator : public Gemm { std::optional trans_b, Tensor c) const override { auto stream = static_cast(stream_); - auto t_self = ascend::buildAclTensor(c); - auto t_a = ascend::buildAclTensor(a, trans_a_); - auto t_b = ascend::buildAclTensor(b, trans_b_); - auto t_out = ascend::buildAclTensor(c); + auto t_self = self_cache_.get(c.data()); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - - if (batched_) { - aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, - alpha_scalar_, t_out, 0, &ws_needed, - &executor); + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); } else { - aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_, - t_out, 0, &ws_needed, &executor); + aclSetInputTensorAddr(executor_, 0, t_self, c.data()); + aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); } - auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); if (batched_) { - aclnnBaddbmm(arena.buf, ws_needed, executor, stream); + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); } else { - aclnnAddmm(arena.buf, ws_needed, executor, stream); + aclnnAddmm(arena.buf, ws_size_, executor_, stream); } - - aclDestroyTensor(t_self); - aclDestroyTensor(t_a); - aclDestroyTensor(t_b); - aclDestroyTensor(t_out); } private: bool batched_; + float alpha_val_; 
+ float beta_val_; + aclScalar* alpha_scalar_ = nullptr; + aclScalar* beta_scalar_ = nullptr; + + mutable ascend::AclTensorCache self_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/linear/kernel.h b/src/ascend/linear/kernel.h new file mode 100644 index 0000000..ec0f4ec --- /dev/null +++ b/src/ascend/linear/kernel.h @@ -0,0 +1,122 @@ +#ifndef INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ +#define INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_addmm.h" +#include "aclnnop/aclnn_baddbmm.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/linear.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Linear { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear(a, b, bias, trans_a, trans_b, out), + batched_{out.ndim() > 2}, + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(out) { + if (has_bias_) { + bias_cache_ = ascend::AclTensorCache(*bias); + alpha_scalar_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + beta_scalar_ = aclCreateScalar(&beta_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (alpha_scalar_) aclDestroyScalar(alpha_scalar_); + if (beta_scalar_) aclDestroyScalar(beta_scalar_); + } + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(out.data()); + + if (has_bias_) { + auto t_bias = 
bias_cache_.get(const_cast(bias->data())); + + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_bias, + const_cast(bias->data())); + aclSetInputTensorAddr(executor_, 1, t_a, + const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, + const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + + if (batched_) { + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); + } else { + aclnnAddmm(arena.buf, ws_size_, executor_, stream); + } + } else { + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, + &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, + const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, + const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); + } + } + + private: + bool batched_; + + mutable ascend::AclTensorCache bias_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + float alpha_storage_ = 1.0f; + + float beta_storage_ = 1.0f; + + aclScalar* alpha_scalar_ = nullptr; + + aclScalar* beta_scalar_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/matmul/kernel.h b/src/ascend/matmul/kernel.h index 4070634..2d98c23 100644 --- 
a/src/ascend/matmul/kernel.h +++ b/src/ascend/matmul/kernel.h @@ -15,28 +15,47 @@ template <> class Operator : public Matmul { public: Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) - : Matmul(a, b, c, trans_a, trans_b) {} + : Matmul(a, b, c, trans_a, trans_b), + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(c) {} + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + } void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) const override { auto stream = static_cast(stream_); - auto t_a = ascend::buildAclTensor(a, trans_a); - auto t_b = ascend::buildAclTensor(b, trans_b); - auto t_out = ascend::buildAclTensor(c); - - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - // cube_math_type = 1: allow fp16 accumulation. - int8_t cube_math_type = 1; - aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_needed, - &executor); - auto& arena = ascend::workspacePool().ensure(stream, ws_needed); - aclnnMatmul(arena.buf, ws_needed, executor, stream); - - aclDestroyTensor(t_a); - aclDestroyTensor(t_b); - aclDestroyTensor(t_out); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); + + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); } + + private: + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + 
+ mutable uint64_t ws_size_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/mul/kernel.h b/src/ascend/mul/kernel.h new file mode 100644 index 0000000..38a0986 --- /dev/null +++ b/src/ascend/mul/kernel.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_ASCEND_MUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/mul.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Mul { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) {} + + ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnMulGetWorkspaceSize(t_in, t_oth, t_out, &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnMul(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/reshape_and_cache/kernel.h b/src/ascend/reshape_and_cache/kernel.h index 609a1ee..3bc0360 100644 --- 
a/src/ascend/reshape_and_cache/kernel.h +++ b/src/ascend/reshape_and_cache/kernel.h @@ -3,67 +3,106 @@ #include #include -#include #include "acl/acl.h" -#include "ascend/device_.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_index_copy.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" #include "base/reshape_and_cache.h" #include "operator.h" namespace infini::ops { +// Device-side scatter via aclnnInplaceIndexCopy. +// +// The previous implementation copied slot_mapping D2H (aclrtSynchronizeStream), +// then issued per-token D2D memcpy in a host loop. For batch=256, this meant +// ~100 us sync + ~500 us host loop overhead. aclnnInplaceIndexCopy performs +// the scatter entirely on the NPU with two ACLNN calls (one for K, one for V), +// eliminating all D2H synchronisation and host-side loops. +// +// Requirement: slot_mapping must contain only non-negative values. Padding +// tokens (slot < 0) must be filtered by the caller before invoking this +// operator. template <> class Operator : public ReshapeAndCache { public: - using ReshapeAndCache::ReshapeAndCache; + Operator(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, Tensor kv_cache_out) + : ReshapeAndCache(key, value, kv_cache, slot_mapping, kv_cache_out), + key_cache_(key), + value_cache_(value), + slot_cache_(slot_mapping) { + auto num_blocks = static_cast(kv_cache.size(1)); + auto bs = static_cast(block_size_); + int64_t total_slots = num_blocks * bs; + int64_t nkv = static_cast(num_kv_heads_); + int64_t hs = static_cast(head_size_); + + aclDataType acl_dt = ascend::toAclDtype(key.dtype()); + + // Flattened K cache view: [total_slots, num_kv_heads, head_size]. + // K cache is kv_cache_out[0], starting at offset 0. + kv_k_cache_ = ascend::AclTensorCache( + {total_slots, nkv, hs}, acl_dt, kv_cache_out.data()); + + // V cache is kv_cache_out[1], offset by stride(0) elements. 
+ v_offset_bytes_ = static_cast(kv_cache_out.stride(0)) * + kv_cache_out.element_size(); + kv_v_cache_ = ascend::AclTensorCache( + {total_slots, nkv, hs}, acl_dt, + static_cast(kv_cache_out.data()) + v_offset_bytes_); + } void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, const Tensor slot_mapping, Tensor kv_cache_out) const override { auto stream = static_cast(stream_); - // Copy slot_mapping to host for address computation. - auto num_tokens = static_cast(num_tokens_); - std::vector slots(num_tokens); - aclrtMemcpyAsync(slots.data(), num_tokens * sizeof(int64_t), - slot_mapping.data(), num_tokens * sizeof(int64_t), - ACL_MEMCPY_DEVICE_TO_HOST, stream); - aclrtSynchronizeStream(stream); + void* kv_k_data = kv_cache_out.data(); + void* kv_v_data = + static_cast(kv_cache_out.data()) + v_offset_bytes_; - auto bs = static_cast(block_size_); - auto row_bytes = static_cast(num_kv_heads_ * head_size_) * - kDataTypeToSize.at(key.dtype()); - - // kv_cache layout: [2, num_blocks, block_size, num_kv_heads, head_size] - // kv_cache[0] = key cache, kv_cache[1] = value cache. - // Stride for the first dim (K vs V): kv_cache.stride(0). - auto kv_stride0 = static_cast(kv_cache_out.stride(0)); - - for (int64_t i = 0; i < num_tokens; ++i) { - auto slot = slots[i]; - if (slot < 0) continue; // Padding token — skip. 
- auto block_idx = slot / bs; - auto offset = slot % bs; - - auto cache_offset = (block_idx * kv_cache_out.stride(1) + - offset * kv_cache_out.stride(2)) * - kv_cache_out.element_size(); - - auto* k_src = static_cast(key.data()) + - i * key.stride(0) * key.element_size(); - auto* k_dst = static_cast(kv_cache_out.data()) + cache_offset; - aclrtMemcpyAsync(k_dst, row_bytes, k_src, row_bytes, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - - auto* v_src = static_cast(value.data()) + - i * value.stride(0) * value.element_size(); - auto* v_dst = static_cast(kv_cache_out.data()) + - kv_stride0 * kv_cache_out.element_size() + cache_offset; - aclrtMemcpyAsync(v_dst, row_bytes, v_src, row_bytes, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } + auto t_kv_k = kv_k_cache_.get(kv_k_data); + auto t_kv_v = kv_v_cache_.get(kv_v_data); + auto t_key = key_cache_.get(const_cast(key.data())); + auto t_value = value_cache_.get(const_cast(value.data())); + auto t_slot = slot_cache_.get(const_cast(slot_mapping.data())); + + // K cache scatter: kv_k[slot_mapping[i]] = key[i] along dim 0. + // Executor caching is not used here because aclnnInplaceIndexCopy is an + // inplace operation where self is both input and output; the executor + // reuse via aclSetInputTensorAddr does not update the output reference. + uint64_t k_ws = 0; + aclOpExecutor* k_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_k, 0, t_slot, t_key, + &k_ws, &k_exec); + auto& k_arena = ascend::workspacePool().ensure(stream, k_ws); + aclnnInplaceIndexCopy(k_arena.buf, k_ws, k_exec, stream); + + // V cache scatter: kv_v[slot_mapping[i]] = value[i] along dim 0. 
+ uint64_t v_ws = 0; + aclOpExecutor* v_exec = nullptr; + aclnnInplaceIndexCopyGetWorkspaceSize(t_kv_v, 0, t_slot, t_value, + &v_ws, &v_exec); + auto& v_arena = ascend::workspacePool().ensure(stream, v_ws); + aclnnInplaceIndexCopy(v_arena.buf, v_ws, v_exec, stream); } + + private: + mutable ascend::AclTensorCache kv_k_cache_; + + mutable ascend::AclTensorCache kv_v_cache_; + + mutable ascend::AclTensorCache key_cache_; + + mutable ascend::AclTensorCache value_cache_; + + mutable ascend::AclTensorCache slot_cache_; + + size_t v_offset_bytes_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/rms_norm/kernel.h b/src/ascend/rms_norm/kernel.h index 9eef1bb..4061936 100644 --- a/src/ascend/rms_norm/kernel.h +++ b/src/ascend/rms_norm/kernel.h @@ -17,44 +17,69 @@ template <> class Operator : public RmsNorm { public: Operator(const Tensor input, const Tensor weight, float eps, Tensor out) - : RmsNorm(input, weight, eps, out) { + : RmsNorm(input, weight, eps, out), + in_cache_(input), + weight_cache_(weight), + out_cache_(out) { // aclnnRmsNorm writes rstd as a required side output. // Allocate a persistent device buffer for it. rstd_shape_ = {static_cast(batch_size_), static_cast(nhead_)}; size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float); aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // The rstd descriptor has a stable data pointer. + rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_data_); } ~Operator() { + if (executor_) aclDestroyAclOpExecutor(executor_); + if (rstd_tensor_) aclDestroyTensor(rstd_tensor_); if (rstd_data_) aclrtFree(rstd_data_); } void operator()(const Tensor input, const Tensor weight, float eps, Tensor out) const override { - auto t_in = ascend::buildAclTensor(input); - auto t_weight = ascend::buildAclTensor(weight); - auto t_out = ascend::buildAclTensor(out); - // rstd is always float32 regardless of input dtype. 
- auto t_rstd = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, - /*strides=*/nullptr, 0, ACL_FORMAT_ND, - rstd_shape_.data(), 2, rstd_data_); - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, t_rstd, &ws_needed, - &executor); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_weight = weight_cache_.get(const_cast(weight.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, rstd_tensor_, + &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_weight, + const_cast(weight.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + // rstd at output index 1 has a stable address — no update needed. + } + auto stream = static_cast(stream_); - auto& arena = ascend::workspacePool().ensure(stream, ws_needed); - aclnnRmsNorm(arena.buf, ws_needed, executor, stream); - aclDestroyTensor(t_in); - aclDestroyTensor(t_weight); - aclDestroyTensor(t_out); - aclDestroyTensor(t_rstd); + auto& arena = ascend::workspacePool().ensure(stream, ws_size_); + aclnnRmsNorm(arena.buf, ws_size_, executor_, stream); } private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache weight_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; + std::vector rstd_shape_; + void* rstd_data_ = nullptr; + + aclTensor* rstd_tensor_ = nullptr; }; } // namespace infini::ops diff --git a/src/ascend/rotary_embedding/kernel.h b/src/ascend/rotary_embedding/kernel.h index 5c3da01..659f91d 100644 --- a/src/ascend/rotary_embedding/kernel.h +++ b/src/ascend/rotary_embedding/kernel.h @@ -10,25 +10,28 @@ #include "aclnn/aclnn_base.h" #include "aclnnop/aclnn_apply_rotary_pos_emb_v2.h" #include 
"aclnnop/aclnn_index_select.h" -#include "aclnnop/aclnn_rotary_position_embedding.h" -#include "ascend/data_type_.h" +#include "ascend/common.h" #include "ascend/workspace_pool_.h" #include "base/rotary_embedding.h" #include "operator.h" namespace infini::ops { -// aclnnApplyRotaryPosEmbV2 hardware constraints on Atlas A2/A3: -// - rotaryMode "half" only (neox style) -// - D (last dim of queryRef) must be 64 or 128 -// - bfloat16 only (float16 accumulates with ~1 ULP error that exceeds -// atol=0.001 in tests; bfloat16 passes with atol=0.005) +// Rotary position embedding via aclnnApplyRotaryPosEmbV2. // -// Use V2 when all three hold; fall back to V1 otherwise. -static bool use_rope_v2(int64_t D, bool is_neox, DataType dtype) { - return is_neox && (D == 64 || D == 128) && dtype == DataType::kBFloat16; -} - +// V2 handles Q and K simultaneously in a single inplace call (layout=4, TND). +// The `rotaryMode` parameter accepts "half", "interleave", or "quarter", but +// CANN currently only supports "half" (neox style). Passing "interleave" or +// "quarter" returns ACLNN_ERR_PARAM_INVALID. +// +// fp16 note: V2 accumulates with ~4 ULP error for float16 (max diff ~0.008), +// which exceeds strict atol=0.001 tests but is acceptable for inference. +// bfloat16 passes with atol=0.005. +// +// Restrictions: +// - rotary_dim must equal head_size (partial rotation not supported). +// - is_neox_style must be true (rotaryMode="half" only). +// All mainstream models (LLaMA, Qwen, Mistral, DeepSeek) satisfy both. 
template <> class Operator : public RotaryEmbedding { @@ -38,118 +41,105 @@ class Operator bool is_neox_style, Tensor query_out, Tensor key_out) : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style, query_out, key_out) { + assert(rotary_dim == head_size && + "Ascend `RotaryEmbedding` requires rotary_dim == head_size " + "(partial rotation not supported)"); + assert(is_neox_style && + "Ascend `RotaryEmbedding` requires neox style — " + "aclnnApplyRotaryPosEmbV2 rotaryMode only supports \"half\"; " + "\"interleave\" and \"quarter\" return ACLNN_ERR_PARAM_INVALID"); + const int64_t max_seq_len = cos_sin_cache.size(0); - const int64_t R = rotary_dim_; - const int64_t half_R = R / 2; - cache_elem_size_ = cos_sin_cache.element_size(); - - // Copy raw cache to host for pre-expansion (one-time cost). - size_t raw_bytes = static_cast(max_seq_len * R) * cache_elem_size_; - std::vector cache_host(raw_bytes); - aclrtMemcpy(cache_host.data(), raw_bytes, cos_sin_cache.data(), raw_bytes, - ACL_MEMCPY_DEVICE_TO_HOST); - - // Pre-expand into separate cos/sin tables with duplicated values. - // After expansion each row is R-wide: - // neox: cos = [c0..c_{hR-1}, c0..c_{hR-1}] (first half repeated) - // interleave: cos = [c0,c0, c1,c1, ..., c_{hR-1},c_{hR-1}] - // Same pattern for sin. - table_bytes_ = raw_bytes; - std::vector cos_table_host(table_bytes_); - std::vector sin_table_host(table_bytes_); + const int64_t D = head_size_; + const int64_t half_D = D / 2; + const size_t elem_sz = cos_sin_cache.element_size(); + + // One-time: D2H copy cos_sin_cache, split cos/sin, expand, upload. + // cos_sin_cache layout per row: [c0..c_{D/2-1}, s0..s_{D/2-1}]. + size_t table_bytes = static_cast(max_seq_len * D) * elem_sz; + std::vector cache_host(table_bytes); + aclrtMemcpy(cache_host.data(), table_bytes, cos_sin_cache.data(), + table_bytes, ACL_MEMCPY_DEVICE_TO_HOST); + + // Pre-expand into separate cos/sin tables [max_seq_len, D]. 
+ // neox: cos = [c0..c_{hD-1}, c0..c_{hD-1}] (halves duplicated) + // interleave: cos = [c0,c0, c1,c1, ..., c_{hD-1},c_{hD-1}] + std::vector cos_host(table_bytes); + std::vector sin_host(table_bytes); for (int64_t p = 0; p < max_seq_len; ++p) { - if (is_neox_style_) { - for (int64_t j = 0; j < half_R; ++j) { - const uint8_t* c_src = - cache_host.data() + - static_cast(p * R + j) * cache_elem_size_; - const uint8_t* s_src = - cache_host.data() + - static_cast(p * R + half_R + j) * cache_elem_size_; - auto* cos_dst = cos_table_host.data(); - auto* sin_dst = sin_table_host.data(); - std::memcpy( - cos_dst + static_cast(p * R + j) * cache_elem_size_, - c_src, cache_elem_size_); - std::memcpy(cos_dst + static_cast(p * R + half_R + j) * - cache_elem_size_, - c_src, cache_elem_size_); - std::memcpy( - sin_dst + static_cast(p * R + j) * cache_elem_size_, - s_src, cache_elem_size_); - std::memcpy(sin_dst + static_cast(p * R + half_R + j) * - cache_elem_size_, - s_src, cache_elem_size_); - } - } else { - for (int64_t j = 0; j < half_R; ++j) { - const uint8_t* c_src = - cache_host.data() + - static_cast(p * R + j) * cache_elem_size_; - const uint8_t* s_src = - cache_host.data() + - static_cast(p * R + half_R + j) * cache_elem_size_; - auto* cos_dst = cos_table_host.data(); - auto* sin_dst = sin_table_host.data(); - std::memcpy( - cos_dst + static_cast(p * R + 2 * j) * cache_elem_size_, - c_src, cache_elem_size_); - std::memcpy(cos_dst + static_cast(p * R + 2 * j + 1) * - cache_elem_size_, - c_src, cache_elem_size_); - std::memcpy( - sin_dst + static_cast(p * R + 2 * j) * cache_elem_size_, - s_src, cache_elem_size_); - std::memcpy(sin_dst + static_cast(p * R + 2 * j + 1) * - cache_elem_size_, - s_src, cache_elem_size_); - } + for (int64_t j = 0; j < half_D; ++j) { + const auto* c_src = + cache_host.data() + + static_cast(p * D + j) * elem_sz; + const auto* s_src = + cache_host.data() + + static_cast(p * D + half_D + j) * elem_sz; + + // Neox expansion: [c0..c_{hD-1}, 
c0..c_{hD-1}] (halves duplicated). + std::memcpy( + cos_host.data() + static_cast(p * D + j) * elem_sz, + c_src, elem_sz); + std::memcpy( + cos_host.data() + + static_cast(p * D + half_D + j) * elem_sz, + c_src, elem_sz); + std::memcpy( + sin_host.data() + static_cast(p * D + j) * elem_sz, + s_src, elem_sz); + std::memcpy( + sin_host.data() + + static_cast(p * D + half_D + j) * elem_sz, + s_src, elem_sz); } } // Upload expanded tables to device (one-time). - aclrtMalloc(&cos_table_dev_, table_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); - aclrtMalloc(&sin_table_dev_, table_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); - aclrtMemcpy(cos_table_dev_, table_bytes_, cos_table_host.data(), - table_bytes_, ACL_MEMCPY_HOST_TO_DEVICE); - aclrtMemcpy(sin_table_dev_, table_bytes_, sin_table_host.data(), - table_bytes_, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMalloc(&cos_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_table_dev_, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(cos_table_dev_, table_bytes, cos_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(sin_table_dev_, table_bytes, sin_host.data(), table_bytes, + ACL_MEMCPY_HOST_TO_DEVICE); const int64_t T = num_tokens_; const int64_t Nq = num_heads_; const int64_t Nkv = num_kv_heads_; - const int64_t D = head_size_; - const bool v2 = use_rope_v2(R, is_neox_style_, query.dtype()); - use_v2_ = v2; - - // Gathered output buffers [T, R] — filled by aclnnIndexSelect at runtime. - gathered_cs_bytes_ = static_cast(T * R) * cache_elem_size_; - aclrtMalloc(&cos_dev_, gathered_cs_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); - aclrtMalloc(&sin_dev_, gathered_cs_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); - - // Scratch for partial-rotation (R < D) — used by both V1 and V2. 
- if (R < D) { - size_t q_rot_bytes = static_cast(T * Nq * R) * cache_elem_size_; - size_t k_rot_bytes = static_cast(T * Nkv * R) * cache_elem_size_; - aclrtMalloc(&q_rot_dev_, q_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); - aclrtMalloc(&k_rot_dev_, k_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); - if (!v2) { - aclrtMalloc(&q_out_rot_dev_, q_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); - aclrtMalloc(&k_out_rot_dev_, k_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); - } - } + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + + // Gathered cos/sin buffers [T, D] — filled by aclnnIndexSelect each call. + size_t gathered_bytes = static_cast(T * D) * elem_sz; + aclrtMalloc(&cos_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_dev_, gathered_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // IndexSelect descriptors: table ptrs stable, positions ptr varies. + cos_table_cache_ = ascend::AclTensorCache( + {max_seq_len, D}, acl_dt, cos_table_dev_); + sin_table_cache_ = ascend::AclTensorCache( + {max_seq_len, D}, acl_dt, sin_table_dev_); + idx_cache_ = ascend::AclTensorCache( + {T}, ACL_INT64, const_cast(positions.data())); + cos_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt, cos_dev_); + sin_out_cache_ = ascend::AclTensorCache({T, D}, acl_dt, sin_dev_); + + // V2 descriptors: cos/sin [T, 1, D], Q [T, Nq, D], K [T, Nkv, D]. 
+ cos_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, cos_dev_); + sin_v2_cache_ = ascend::AclTensorCache({T, 1, D}, acl_dt, sin_dev_); + q_cache_ = ascend::AclTensorCache( + {T, Nq, D}, acl_dt, const_cast(query_out.data())); + k_cache_ = ascend::AclTensorCache( + {T, Nkv, D}, acl_dt, const_cast(key_out.data())); } ~Operator() { + if (idx_cos_exec_) aclDestroyAclOpExecutor(idx_cos_exec_); + if (idx_sin_exec_) aclDestroyAclOpExecutor(idx_sin_exec_); + if (v2_exec_) aclDestroyAclOpExecutor(v2_exec_); + if (cos_table_dev_) aclrtFree(cos_table_dev_); if (sin_table_dev_) aclrtFree(sin_table_dev_); if (cos_dev_) aclrtFree(cos_dev_); if (sin_dev_) aclrtFree(sin_dev_); - if (q_rot_dev_) aclrtFree(q_rot_dev_); - if (k_rot_dev_) aclrtFree(k_rot_dev_); - if (q_out_rot_dev_) aclrtFree(q_out_rot_dev_); - if (k_out_rot_dev_) aclrtFree(k_out_rot_dev_); } void operator()(const Tensor positions, const Tensor query, const Tensor key, @@ -162,342 +152,120 @@ class Operator const int64_t Nq = query.size(1); const int64_t Nkv = key.size(1); const int64_t D = head_size; - const int64_t R = rotary_dim; - const int64_t max_seq_len = cos_sin_cache.size(0); - - assert(R <= D); - assert(cos_sin_cache.size(1) == R); - // 1. Gather cos/sin on device via aclnnIndexSelect — fully async. - // No host sync, no D2H copy. Positions stay on device. + // Step 1: Gather cos/sin by positions via aclnnIndexSelect (async). 
{ - aclDataType acl_dt_cs = ascend::toAclDtype(query.dtype()); - - // Table tensors: [max_seq_len, R] - std::vector table_shape = {max_seq_len, R}; - std::vector table_strides = {R, 1}; - std::vector table_storage = {max_seq_len * R}; - - aclTensor* t_cos_table = aclCreateTensor( - table_shape.data(), 2, acl_dt_cs, table_strides.data(), 0, - ACL_FORMAT_ND, table_storage.data(), 1, cos_table_dev_); - aclTensor* t_sin_table = aclCreateTensor( - table_shape.data(), 2, acl_dt_cs, table_strides.data(), 0, - ACL_FORMAT_ND, table_storage.data(), 1, sin_table_dev_); - - // Index tensor: positions [T], int64 — stays on device. - std::vector idx_shape = {T}; - std::vector idx_strides = {1}; - std::vector idx_storage = {T}; - aclTensor* t_idx = aclCreateTensor( - idx_shape.data(), 1, ACL_INT64, idx_strides.data(), 0, ACL_FORMAT_ND, - idx_storage.data(), 1, const_cast(positions.data())); - - // Output tensors: [T, R] - std::vector out_shape = {T, R}; - std::vector out_strides = {R, 1}; - std::vector out_storage = {T * R}; - - aclTensor* t_cos_out = - aclCreateTensor(out_shape.data(), 2, acl_dt_cs, out_strides.data(), 0, - ACL_FORMAT_ND, out_storage.data(), 1, cos_dev_); - aclTensor* t_sin_out = - aclCreateTensor(out_shape.data(), 2, acl_dt_cs, out_strides.data(), 0, - ACL_FORMAT_ND, out_storage.data(), 1, sin_dev_); - - // Get workspace sizes and executors for both gathers. - uint64_t ws_cos = 0, ws_sin = 0; - aclOpExecutor *exec_cos = nullptr, *exec_sin = nullptr; - aclnnIndexSelectGetWorkspaceSize(t_cos_table, 0, t_idx, t_cos_out, - &ws_cos, &exec_cos); - aclnnIndexSelectGetWorkspaceSize(t_sin_table, 0, t_idx, t_sin_out, - &ws_sin, &exec_sin); - - // Single workspace buffer large enough for both calls. - uint64_t ws_max = ws_cos > ws_sin ? 
ws_cos : ws_sin; + auto t_cos_table = cos_table_cache_.get(cos_table_dev_); + auto t_sin_table = sin_table_cache_.get(sin_table_dev_); + auto t_idx = idx_cache_.get(const_cast(positions.data())); + auto t_cos_out = cos_out_cache_.get(cos_dev_); + auto t_sin_out = sin_out_cache_.get(sin_dev_); + + if (!idx_cos_exec_) { + aclnnIndexSelectGetWorkspaceSize(t_cos_table, 0, t_idx, t_cos_out, + &idx_cos_ws_, &idx_cos_exec_); + aclSetAclOpExecutorRepeatable(idx_cos_exec_); + } else { + aclSetInputTensorAddr(idx_cos_exec_, 1, t_idx, + const_cast(positions.data())); + } + + if (!idx_sin_exec_) { + aclnnIndexSelectGetWorkspaceSize(t_sin_table, 0, t_idx, t_sin_out, + &idx_sin_ws_, &idx_sin_exec_); + aclSetAclOpExecutorRepeatable(idx_sin_exec_); + } else { + aclSetInputTensorAddr(idx_sin_exec_, 1, t_idx, + const_cast(positions.data())); + } + + uint64_t ws_max = idx_cos_ws_ > idx_sin_ws_ ? idx_cos_ws_ : idx_sin_ws_; auto& arena = ascend::workspacePool().ensure(stream, ws_max); - aclnnIndexSelect(arena.buf, ws_cos, exec_cos, stream); - aclnnIndexSelect(arena.buf, ws_sin, exec_sin, stream); + aclnnIndexSelect(arena.buf, idx_cos_ws_, idx_cos_exec_, stream); + aclnnIndexSelect(arena.buf, idx_sin_ws_, idx_sin_exec_, stream); + } - aclDestroyTensor(t_cos_table); - aclDestroyTensor(t_sin_table); - aclDestroyTensor(t_idx); - aclDestroyTensor(t_cos_out); - aclDestroyTensor(t_sin_out); + // Step 2: Copy q→q_out, k→k_out if not inplace (V2 operates inplace). 
+ size_t elem_sz = query.element_size(); + + if (query.data() != query_out.data()) { + aclrtMemcpyAsync(query_out.data(), + static_cast(T * Nq * D) * elem_sz, query.data(), + static_cast(T * Nq * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); } - aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + if (key.data() != key_out.data()) { + aclrtMemcpyAsync(key_out.data(), + static_cast(T * Nkv * D) * elem_sz, key.data(), + static_cast(T * Nkv * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } - if (use_v2_) { - // V2: fused Q+K, in-place, layout=4 (T-first 3D), "half" mode. - // cos/sin shape: [T, 1, R]. - std::vector cs_shape = {T, 1, R}; - std::vector cs_strides = {R, R, 1}; - std::vector cs_storage = {T * R}; - aclTensor* t_cos = - aclCreateTensor(cs_shape.data(), 3, acl_dt, cs_strides.data(), 0, - ACL_FORMAT_ND, cs_storage.data(), 1, cos_dev_); - aclTensor* t_sin = - aclCreateTensor(cs_shape.data(), 3, acl_dt, cs_strides.data(), 0, - ACL_FORMAT_ND, cs_storage.data(), 1, sin_dev_); - - int64_t layout = 4; - if (R == D) { - apply_rope_v2_full(query, key, query_out, key_out, T, Nq, Nkv, D, - acl_dt, t_cos, t_sin, layout, stream); - } else { - apply_rope_v2_partial(query, key, query_out, key_out, T, Nq, Nkv, D, R, - acl_dt, t_cos, t_sin, layout, stream); - } - aclDestroyTensor(t_cos); - aclDestroyTensor(t_sin); + // Step 3: Apply V2 RoPE inplace on q_out and k_out. + auto t_cos = cos_v2_cache_.get(cos_dev_); + auto t_sin = sin_v2_cache_.get(sin_dev_); + auto t_q = q_cache_.get(query_out.data()); + auto t_k = k_cache_.get(key_out.data()); + + if (!v2_exec_) { + aclnnApplyRotaryPosEmbV2GetWorkspaceSize( + t_q, t_k, t_cos, t_sin, /*layout=*/4, const_cast("half"), + &v2_ws_, &v2_exec_); + aclSetAclOpExecutorRepeatable(v2_exec_); } else { - // V1: separate Q and K calls, non-in-place, [1,T,1,R] cos/sin. 
- std::vector cs_shape = {1, T, 1, R}; - std::vector cs_strides = {T * R, R, R, 1}; - std::vector cs_storage = {T * R}; - aclTensor* t_cos = - aclCreateTensor(cs_shape.data(), 4, acl_dt, cs_strides.data(), 0, - ACL_FORMAT_ND, cs_storage.data(), 1, cos_dev_); - aclTensor* t_sin = - aclCreateTensor(cs_shape.data(), 4, acl_dt, cs_strides.data(), 0, - ACL_FORMAT_ND, cs_storage.data(), 1, sin_dev_); - - int64_t mode = is_neox_style ? 0 : 1; - apply_rope_v1(query, query_out, T, Nq, D, R, mode, t_cos, t_sin, - q_rot_dev_, q_out_rot_dev_, stream); - apply_rope_v1(key, key_out, T, Nkv, D, R, mode, t_cos, t_sin, k_rot_dev_, - k_out_rot_dev_, stream); - - aclDestroyTensor(t_cos); - aclDestroyTensor(t_sin); + aclSetInputTensorAddr(v2_exec_, 0, t_q, query_out.data()); + aclSetInputTensorAddr(v2_exec_, 1, t_k, key_out.data()); } + + auto& arena = ascend::workspacePool().ensure(stream, v2_ws_); + aclnnApplyRotaryPosEmbV2(arena.buf, v2_ws_, v2_exec_, stream); } private: - size_t cache_elem_size_ = 1; - - // Pre-expanded cos/sin tables on device: [max_seq_len, R]. - // Built once in the constructor with neox/interleave duplication. + // Pre-expanded cos/sin tables on device: [max_seq_len, D]. void* cos_table_dev_ = nullptr; - void* sin_table_dev_ = nullptr; - size_t table_bytes_ = 0; - // true when V2 hardware constraints are met (neox, D∈{64,128}, bf16). - bool use_v2_ = false; + void* sin_table_dev_ = nullptr; - // Device buffers for gathered [T, R] cos/sin (shared by V1 and V2). + // Device buffers for gathered [T, D] cos/sin. void* cos_dev_ = nullptr; + void* sin_dev_ = nullptr; - size_t gathered_cs_bytes_ = 0; - - // Scratch for partial rotation (R < D). 
- void* q_rot_dev_ = nullptr; - void* k_rot_dev_ = nullptr; - void* q_out_rot_dev_ = nullptr; - void* k_out_rot_dev_ = nullptr; - - // --- V2 helpers (neox bf16, D∈{64,128}) --- - - void apply_rope_v2_full(const Tensor& q, const Tensor& k, Tensor& q_out, - Tensor& k_out, int64_t T, int64_t Nq, int64_t Nkv, - int64_t D, aclDataType acl_dt, aclTensor* t_cos, - aclTensor* t_sin, int64_t layout, - aclrtStream stream) const { - size_t elem_sz = q.element_size(); - if (q.data() != q_out.data()) { - aclrtMemcpyAsync(const_cast(q_out.data()), - static_cast(T * Nq * D) * elem_sz, q.data(), - static_cast(T * Nq * D) * elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - if (k.data() != k_out.data()) { - size_t k_elem_sz = k.element_size(); - aclrtMemcpyAsync(const_cast(k_out.data()), - static_cast(T * Nkv * D) * k_elem_sz, k.data(), - static_cast(T * Nkv * D) * k_elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - std::vector q_shape = {T, Nq, D}; - std::vector q_strides = {Nq * D, D, 1}; - std::vector q_storage = {T * Nq * D}; - std::vector k_shape = {T, Nkv, D}; - std::vector k_strides = {Nkv * D, D, 1}; - std::vector k_storage = {T * Nkv * D}; - aclTensor* t_q = aclCreateTensor( - q_shape.data(), 3, acl_dt, q_strides.data(), 0, ACL_FORMAT_ND, - q_storage.data(), 1, const_cast(q_out.data())); - aclTensor* t_k = aclCreateTensor( - k_shape.data(), 3, acl_dt, k_strides.data(), 0, ACL_FORMAT_ND, - k_storage.data(), 1, const_cast(k_out.data())); - uint64_t ws = 0; - aclOpExecutor* exec = nullptr; - aclnnApplyRotaryPosEmbV2GetWorkspaceSize( - t_q, t_k, t_cos, t_sin, layout, const_cast("half"), &ws, &exec); - auto& arena = ascend::workspacePool().ensure(stream, ws); - aclnnApplyRotaryPosEmbV2(arena.buf, ws, exec, stream); - aclDestroyTensor(t_q); - aclDestroyTensor(t_k); - } - void apply_rope_v2_partial(const Tensor& q, const Tensor& k, Tensor& q_out, - Tensor& k_out, int64_t T, int64_t Nq, int64_t Nkv, - int64_t D, int64_t R, aclDataType acl_dt, - aclTensor* t_cos, 
aclTensor* t_sin, int64_t layout, - aclrtStream stream) const { - size_t elem_sz = q.element_size(); - size_t k_elem_sz = k.element_size(); - const int64_t pass = D - R; - - for (int64_t i = 0; i < T * Nq; ++i) { - aclrtMemcpyAsync(static_cast(q_rot_dev_) + - static_cast(i * R) * elem_sz, - static_cast(R) * elem_sz, - static_cast(q.data()) + - static_cast(i * D) * elem_sz, - static_cast(R) * elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - for (int64_t i = 0; i < T * Nkv; ++i) { - aclrtMemcpyAsync(static_cast(k_rot_dev_) + - static_cast(i * R) * k_elem_sz, - static_cast(R) * k_elem_sz, - static_cast(k.data()) + - static_cast(i * D) * k_elem_sz, - static_cast(R) * k_elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - std::vector qr_shape = {T, Nq, R}; - std::vector qr_strides = {Nq * R, R, 1}; - std::vector qr_storage = {T * Nq * R}; - std::vector kr_shape = {T, Nkv, R}; - std::vector kr_strides = {Nkv * R, R, 1}; - std::vector kr_storage = {T * Nkv * R}; - aclTensor* t_q_rot = - aclCreateTensor(qr_shape.data(), 3, acl_dt, qr_strides.data(), 0, - ACL_FORMAT_ND, qr_storage.data(), 1, q_rot_dev_); - aclTensor* t_k_rot = - aclCreateTensor(kr_shape.data(), 3, acl_dt, kr_strides.data(), 0, - ACL_FORMAT_ND, kr_storage.data(), 1, k_rot_dev_); - uint64_t ws = 0; - aclOpExecutor* exec = nullptr; - aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q_rot, t_k_rot, t_cos, t_sin, - layout, const_cast("half"), - &ws, &exec); - auto& arena = ascend::workspacePool().ensure(stream, ws); - aclnnApplyRotaryPosEmbV2(arena.buf, ws, exec, stream); - aclDestroyTensor(t_q_rot); - aclDestroyTensor(t_k_rot); - - for (int64_t i = 0; i < T * Nq; ++i) { - aclrtMemcpyAsync(static_cast(const_cast(q_out.data())) + - static_cast(i * D) * elem_sz, - static_cast(R) * elem_sz, - static_cast(q_rot_dev_) + - static_cast(i * R) * elem_sz, - static_cast(R) * elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - aclrtMemcpyAsync(static_cast(const_cast(q_out.data())) + - static_cast(i * D + R) * elem_sz, - 
static_cast(pass) * elem_sz, - static_cast(q.data()) + - static_cast(i * D + R) * elem_sz, - static_cast(pass) * elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - for (int64_t i = 0; i < T * Nkv; ++i) { - aclrtMemcpyAsync(static_cast(const_cast(k_out.data())) + - static_cast(i * D) * k_elem_sz, - static_cast(R) * k_elem_sz, - static_cast(k_rot_dev_) + - static_cast(i * R) * k_elem_sz, - static_cast(R) * k_elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - aclrtMemcpyAsync(static_cast(const_cast(k_out.data())) + - static_cast(i * D + R) * k_elem_sz, - static_cast(pass) * k_elem_sz, - static_cast(k.data()) + - static_cast(i * D + R) * k_elem_sz, - static_cast(pass) * k_elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - } + // IndexSelect descriptors. + mutable ascend::AclTensorCache cos_table_cache_; - // --- V1 helper (fallback for non-neox, fp16, or D not in {64,128}) --- - - void apply_rope_v1(const Tensor& x, Tensor& out, int64_t T, int64_t N, - int64_t D, int64_t R, int64_t mode, aclTensor* t_cos, - aclTensor* t_sin, void* x_rot_dev, void* out_rot_dev, - aclrtStream stream) const { - aclDataType acl_dt = ascend::toAclDtype(x.dtype()); - size_t elem_sz = x.element_size(); - - if (R < D) { - for (int64_t i = 0; i < T * N; ++i) { - aclrtMemcpyAsync(static_cast(x_rot_dev) + - static_cast(i * R) * elem_sz, - static_cast(R) * elem_sz, - static_cast(x.data()) + - static_cast(i * D) * elem_sz, - static_cast(R) * elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - std::vector rot_sh = {1, T, N, R}; - std::vector rot_st = {T * N * R, N * R, R, 1}; - std::vector rot_storage = {T * N * R}; - aclTensor* t_x_rot = - aclCreateTensor(rot_sh.data(), 4, acl_dt, rot_st.data(), 0, - ACL_FORMAT_ND, rot_storage.data(), 1, x_rot_dev); - aclTensor* t_out_rot = - aclCreateTensor(rot_sh.data(), 4, acl_dt, rot_st.data(), 0, - ACL_FORMAT_ND, rot_storage.data(), 1, out_rot_dev); - uint64_t ws = 0; - aclOpExecutor* exec = nullptr; - 
aclnnRotaryPositionEmbeddingGetWorkspaceSize(t_x_rot, t_cos, t_sin, mode, - t_out_rot, &ws, &exec); - auto& arena = ascend::workspacePool().ensure(stream, ws); - aclnnRotaryPositionEmbedding(arena.buf, ws, exec, stream); - - const int64_t pass = D - R; - for (int64_t i = 0; i < T * N; ++i) { - aclrtMemcpyAsync(static_cast(const_cast(out.data())) + - static_cast(i * D) * elem_sz, - static_cast(R) * elem_sz, - static_cast(out_rot_dev) + - static_cast(i * R) * elem_sz, - static_cast(R) * elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - aclrtMemcpyAsync(static_cast(const_cast(out.data())) + - static_cast(i * D + R) * elem_sz, - static_cast(pass) * elem_sz, - static_cast(x.data()) + - static_cast(i * D + R) * elem_sz, - static_cast(pass) * elem_sz, - ACL_MEMCPY_DEVICE_TO_DEVICE, stream); - } - aclDestroyTensor(t_x_rot); - aclDestroyTensor(t_out_rot); - } else { - std::vector full_sh = {1, T, N, D}; - std::vector full_st = {T * N * D, N * D, D, 1}; - std::vector full_storage = {T * N * D}; - aclTensor* t_x = aclCreateTensor( - full_sh.data(), 4, acl_dt, full_st.data(), 0, ACL_FORMAT_ND, - full_storage.data(), 1, const_cast(x.data())); - aclTensor* t_out = aclCreateTensor( - full_sh.data(), 4, acl_dt, full_st.data(), 0, ACL_FORMAT_ND, - full_storage.data(), 1, const_cast(out.data())); - uint64_t ws = 0; - aclOpExecutor* exec = nullptr; - aclnnRotaryPositionEmbeddingGetWorkspaceSize(t_x, t_cos, t_sin, mode, - t_out, &ws, &exec); - auto& arena = ascend::workspacePool().ensure(stream, ws); - aclnnRotaryPositionEmbedding(arena.buf, ws, exec, stream); - aclDestroyTensor(t_x); - aclDestroyTensor(t_out); - } - } + mutable ascend::AclTensorCache sin_table_cache_; + + mutable ascend::AclTensorCache idx_cache_; + + mutable ascend::AclTensorCache cos_out_cache_; + + mutable ascend::AclTensorCache sin_out_cache_; + + // V2 descriptors. 
+ mutable ascend::AclTensorCache cos_v2_cache_; + + mutable ascend::AclTensorCache sin_v2_cache_; + + mutable ascend::AclTensorCache q_cache_; + + mutable ascend::AclTensorCache k_cache_; + + // Cached executors. + mutable aclOpExecutor* idx_cos_exec_ = nullptr; + + mutable uint64_t idx_cos_ws_ = 0; + + mutable aclOpExecutor* idx_sin_exec_ = nullptr; + + mutable uint64_t idx_sin_ws_ = 0; + + mutable aclOpExecutor* v2_exec_ = nullptr; + + mutable uint64_t v2_ws_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/swiglu/kernel.h b/src/ascend/swiglu/kernel.h index c7d31e7..b315989 100644 --- a/src/ascend/swiglu/kernel.h +++ b/src/ascend/swiglu/kernel.h @@ -22,47 +22,76 @@ template <> class Operator : public Swiglu { public: Operator(const Tensor input, const Tensor gate, Tensor out) - : Swiglu(input, gate, out) { + : Swiglu(input, gate, out), + in_cache_(input), + gate_cache_(gate), + out_cache_(out) { size_t nbytes = input.numel() * kDataTypeToSize.at(input.dtype()); aclrtMalloc(&temp_buf_, nbytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // Build temp cache from gate geometry (contiguous, same shape/dtype). + Tensor temp_t{temp_buf_, gate.shape(), gate.dtype(), gate.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); } - ~Operator() { aclrtFree(temp_buf_); } + ~Operator() { + if (silu_exec_) aclDestroyAclOpExecutor(silu_exec_); + if (mul_exec_) aclDestroyAclOpExecutor(mul_exec_); + aclrtFree(temp_buf_); + } void operator()(const Tensor input, const Tensor gate, Tensor out) const override { - // temp_buf_ is a contiguous scratch buffer; give it contiguous strides. 
- Tensor temp_t{temp_buf_, gate.shape(), gate.dtype(), gate.device()}; - - auto t_in = ascend::buildAclTensor(input); - auto t_gate = ascend::buildAclTensor(gate); - auto t_out = ascend::buildAclTensor(out); - auto t_temp = ascend::buildAclTensor(temp_t); - - uint64_t ws_needed = 0; - aclOpExecutor* exec = nullptr; + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_gate = gate_cache_.get(const_cast(gate.data())); + auto t_out = out_cache_.get(out.data()); + auto t_temp = temp_cache_.get(temp_buf_); auto stream = static_cast(stream_); - // Step 1: silu(gate) -> temp. SwiGLU = input * silu(gate). - aclnnSiluGetWorkspaceSize(t_gate, t_temp, &ws_needed, &exec); - auto& silu_arena = ascend::workspacePool().ensure(stream, ws_needed); - aclnnSilu(silu_arena.buf, ws_needed, exec, stream); + // Step 1: silu(gate) -> temp. + if (!silu_exec_) { + aclnnSiluGetWorkspaceSize(t_gate, t_temp, &silu_ws_, &silu_exec_); + aclSetAclOpExecutorRepeatable(silu_exec_); + } else { + aclSetInputTensorAddr(silu_exec_, 0, t_gate, + const_cast(gate.data())); + aclSetOutputTensorAddr(silu_exec_, 0, t_temp, temp_buf_); + } + auto& silu_arena = ascend::workspacePool().ensure(stream, silu_ws_); + aclnnSilu(silu_arena.buf, silu_ws_, silu_exec_, stream); // Step 2: mul(input, temp) -> out. 
- uint64_t mul_ws = 0; - exec = nullptr; - aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws, &exec); - auto& mul_arena = ascend::workspacePool().ensure(stream, mul_ws); - aclnnMul(mul_arena.buf, mul_ws, exec, stream); - - aclDestroyTensor(t_in); - aclDestroyTensor(t_gate); - aclDestroyTensor(t_out); - aclDestroyTensor(t_temp); + if (!mul_exec_) { + aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws_, &mul_exec_); + aclSetAclOpExecutorRepeatable(mul_exec_); + } else { + aclSetInputTensorAddr(mul_exec_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(mul_exec_, 1, t_temp, temp_buf_); + aclSetOutputTensorAddr(mul_exec_, 0, t_out, out.data()); + } + auto& mul_arena = ascend::workspacePool().ensure(stream, mul_ws_); + aclnnMul(mul_arena.buf, mul_ws_, mul_exec_, stream); } private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache gate_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + void* temp_buf_ = nullptr; + + mutable aclOpExecutor* silu_exec_ = nullptr; + + mutable uint64_t silu_ws_ = 0; + + mutable aclOpExecutor* mul_exec_ = nullptr; + + mutable uint64_t mul_ws_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h index bac2479..3d0f698 100644 --- a/src/ascend/workspace_pool_.h +++ b/src/ascend/workspace_pool_.h @@ -2,7 +2,9 @@ #define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ #include +#include #include +#include #include #include @@ -18,24 +20,57 @@ struct WorkspaceArena { class WorkspacePool { public: WorkspaceArena& ensure(aclrtStream stream, uint64_t needed) { + // Thread-local fast path: skip mutex when the same stream's arena already + // has enough capacity. After warmup (first call per operator), workspace + // sizes are fixed and this path is always taken. + // + // NOTE: Only the most recent stream is cached. If a single thread + // alternates between multiple streams (e.g. 
TP>1 driven by one thread), + // every stream switch falls back to the slow path. Replace with a + // small thread-local map if multi-stream-per-thread becomes common. + thread_local aclrtStream last_stream = nullptr; + thread_local WorkspaceArena* last_arena = nullptr; + + if (stream == last_stream && last_arena != nullptr && + needed <= last_arena->capacity) { + return *last_arena; + } + + // Slow path: look up arena in the map under lock. std::lock_guard lock(mutex_); auto& arena = arenas_[stream]; - if (needed <= arena.capacity) return arena; - if (arena.capacity > 0) { - aclrtSynchronizeStream(stream); - aclrtFree(arena.buf); - } - if (needed > 0) { - auto ret = aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); - assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + if (needed > arena.capacity) { + if (arena.capacity > 0) { + aclrtSynchronizeStream(stream); + aclrtFree(arena.buf); + } + if (needed > 0) { + auto ret = aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + } + arena.capacity = needed; } - arena.capacity = needed; + last_stream = stream; + last_arena = &arena; return arena; } ~WorkspacePool() { for (auto& [stream, arena] : arenas_) { - if (arena.capacity > 0) aclrtFree(arena.buf); + if (arena.capacity > 0) { + // The CANN runtime may already be torn down when this static + // destructor runs. aclrtGetDevice fails in that case — skip the + // free to avoid glibc "double free" abort. 
+ int32_t dev_id = -1; + if (aclrtGetDevice(&dev_id) == ACL_SUCCESS) { + aclrtFree(arena.buf); + } else { + fprintf(stderr, + "[InfiniOps] `WorkspacePool`: CANN runtime already finalized, " + "skipping `aclrtFree` (%" PRIu64 " bytes leaked).\n", + arena.capacity); + } + } } } diff --git a/src/base/cast.h b/src/base/cast.h new file mode 100644 index 0000000..29f1f40 --- /dev/null +++ b/src/base/cast.h @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_BASE_CAST_H_ +#define INFINI_OPS_BASE_CAST_H_ + +#include "operator.h" + +namespace infini::ops { + +class Cast : public Operator { + public: + Cast(const Tensor input, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_dtype_{input.dtype()}, + out_dtype_{out.dtype()}, + input_shape_{input.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(input.numel() == out.numel() && + "the input and output of `Cast` must have the same number of " + "elements"); + } + + virtual void operator()(const Tensor input, Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_dtype_; + + const DataType out_dtype_; + + Tensor::Shape input_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cat.h b/src/base/cat.h new file mode 100644 index 0000000..16f9bd2 --- /dev/null +++ b/src/base/cat.h @@ -0,0 +1,34 @@ +#ifndef INFINI_OPS_BASE_CAT_H_ +#define INFINI_OPS_BASE_CAT_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +class Cat : public Operator { + public: + Cat(const Tensor first_input, std::vector rest_inputs, int64_t dim, + Tensor out) + : dim_{dim}, input_count_{1 + rest_inputs.size()} { + 
assert(input_count_ >= 2 && "Cat requires at least 2 input tensors"); + + auto ndim = out.ndim(); + assert(dim >= 0 && dim < static_cast(ndim) && + "Cat dim out of range"); + } + + virtual void operator()(const Tensor first_input, + std::vector rest_inputs, int64_t dim, + Tensor out) const = 0; + + protected: + int64_t dim_; + + size_t input_count_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/linear.h b/src/base/linear.h new file mode 100644 index 0000000..520617f --- /dev/null +++ b/src/base/linear.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_BASE_LINEAR_H_ +#define INFINI_OPS_BASE_LINEAR_H_ + +#include + +#include "operator.h" + +namespace infini::ops { + +// Fused linear projection: out = a @ b (+ bias). +// +// When bias is present, computes out = a @ b + bias in a single dispatch. +// When bias is absent, computes out = a @ b (equivalent to Matmul). +// trans_a / trans_b: if true, transpose the last two dims before multiplying. +class Linear : public Operator { + public: + Linear(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + out_shape_{out.shape()}, + a_strides_{a.strides()}, + b_strides_{b.strides()}, + out_strides_{out.strides()}, + trans_a_{trans_a}, + trans_b_{trans_b}, + has_bias_{bias.has_value()} { + assert(a.dtype() == b.dtype() && + "operator `Linear` requires a and b to have the same dtype"); + assert(a.dtype() == out.dtype() && + "operator `Linear` requires a and out to have the same dtype"); + if (has_bias_) { + assert(bias->dtype() == out.dtype() && + "operator `Linear` requires bias and out to have the same dtype"); + } + } + + virtual void operator()(const Tensor a, const Tensor b, + std::optional bias, bool trans_a, + bool trans_b, Tensor out) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides a_strides_; + + Tensor::Strides b_strides_; + + 
Tensor::Strides out_strides_; + + bool trans_a_{false}; + + bool trans_b_{false}; + + bool has_bias_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mul.h b/src/base/mul.h new file mode 100644 index 0000000..9e7be22 --- /dev/null +++ b/src/base/mul.h @@ -0,0 +1,67 @@ +#ifndef INFINI_OPS_BASE_MUL_H_ +#define INFINI_OPS_BASE_MUL_H_ + +#include "operator.h" + +namespace infini::ops { + +class Mul : public Operator { + public: + Mul(const Tensor input, const Tensor other, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + other_type_{other.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, + other_shape_{other.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + other_strides_{other.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_other_contiguous_{other.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(!out.HasBroadcastDim() && + "the output of `Mul` should NOT have broadcasted dim!"); + assert(input_type_ == other_type_ && other_type_ == out_type_ && + "operator `Mul` requires all input and output tensors to have the " + "same dtype"); + } + + virtual void operator()(const Tensor input, const Tensor other, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType other_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape other_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides other_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_other_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cast/cast.h b/src/cpu/cast/cast.h new file mode 100644 index 0000000..67c8367 --- /dev/null +++ b/src/cpu/cast/cast.h @@ -0,0 +1,57 @@ 
+#ifndef INFINI_OPS_CPU_CAST_CAST_H_ +#define INFINI_OPS_CPU_CAST_CAST_H_ + +#include "base/cast.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) : Cast{input, out} {} + + void operator()(const Tensor input, Tensor out) const override { + DispatchFunc( + input_dtype_, + [&](auto in_tag) { + using InT = typename decltype(in_tag)::type; + DispatchFunc( + out_dtype_, + [&](auto out_tag) { + using OutT = typename decltype(out_tag)::type; + Compute(input, out); + }, + "`Operator::operator()` (out)"); + }, + "`Operator::operator()` (in)"); + } + + private: + template + void Compute(const Tensor input, Tensor out) const { + const auto* in_ptr = static_cast(input.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto in_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = + Caster::template Cast(in_ptr[in_idx]); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cat/cat.h b/src/cpu/cat/cat.h new file mode 100644 index 0000000..d49b023 --- /dev/null +++ b/src/cpu/cat/cat.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_CPU_CAT_CAT_H_ +#define INFINI_OPS_CPU_CAT_CAT_H_ + +#include +#include + +#include "base/cat.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat{first_input, rest_inputs, dim, out} {} + + void operator()(const Tensor first_input, std::vector 
rest_inputs, + int64_t dim, Tensor out) const override { + // Collect all input tensors. + std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + auto elem_size = kDataTypeToSize.at(out.dtype()); + auto ndim = out.ndim(); + auto out_shape = out.shape(); + + // Compute outer and inner sizes relative to the cat dimension. + Tensor::Size outer = 1; + for (int64_t i = 0; i < dim; ++i) { + outer *= out_shape[i]; + } + + Tensor::Size inner = 1; + for (size_t i = static_cast(dim) + 1; i < ndim; ++i) { + inner *= out_shape[i]; + } + + auto* out_ptr = static_cast(out.data()); + Tensor::Size out_dim_size = out_shape[dim]; + + // For each outer index, copy slices from each input along the cat dim. + for (Tensor::Size o = 0; o < outer; ++o) { + Tensor::Size offset_in_dim = 0; + + for (size_t t = 0; t < input_count_; ++t) { + auto in_dim = inputs[t]->shape()[dim]; + auto in_ptr = static_cast(inputs[t]->data()); + + auto src_offset = (o * in_dim) * inner * elem_size; + auto dst_offset = (o * out_dim_size + offset_in_dim) * inner * elem_size; + auto copy_size = in_dim * inner * elem_size; + + std::memcpy(out_ptr + dst_offset, in_ptr + src_offset, copy_size); + offset_in_dim += in_dim; + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/linear/linear.h b/src/cpu/linear/linear.h new file mode 100644 index 0000000..89f22fa --- /dev/null +++ b/src/cpu/linear/linear.h @@ -0,0 +1,112 @@ +#ifndef INFINI_OPS_CPU_LINEAR_LINEAR_H_ +#define INFINI_OPS_CPU_LINEAR_LINEAR_H_ + +#include + +#include "base/linear.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Linear, + Caster { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear{a, b, bias, trans_a, trans_b, out} {} + + void operator()(const Tensor a, const Tensor b, 
std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(a, b, bias, trans_a, trans_b, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const { + const auto* A = static_cast(a.data()); + const auto* B = static_cast(b.data()); + auto* Out = static_cast(out.data()); + const T* Bias = bias ? static_cast(bias->data()) : nullptr; + + // Determine M, K, N from shapes and transpose flags. + auto ndim_a = a_shape_.size(); + auto ndim_b = b_shape_.size(); + auto ndim_out = out_shape_.size(); + + Tensor::Size M = out_shape_[ndim_out - 2]; + Tensor::Size N = out_shape_[ndim_out - 1]; + Tensor::Size K = trans_a ? a_shape_[ndim_a - 2] : a_shape_[ndim_a - 1]; + + // Compute strides for the inner matrix dimensions after transpose. + Tensor::Stride stride_a_m = trans_a ? a_strides_[ndim_a - 1] + : a_strides_[ndim_a - 2]; + Tensor::Stride stride_a_k = trans_a ? a_strides_[ndim_a - 2] + : a_strides_[ndim_a - 1]; + Tensor::Stride stride_b_k = trans_b ? b_strides_[ndim_b - 1] + : b_strides_[ndim_b - 2]; + Tensor::Stride stride_b_n = trans_b ? b_strides_[ndim_b - 2] + : b_strides_[ndim_b - 1]; + Tensor::Stride stride_out_m = out_strides_[ndim_out - 2]; + Tensor::Stride stride_out_n = out_strides_[ndim_out - 1]; + + // Batch dimensions. + Tensor::Size batch_count = 1; + for (size_t i = 0; i + 2 < ndim_out; ++i) { + batch_count *= out_shape_[i]; + } + + Tensor::Stride batch_stride_a = + ndim_a > 2 ? a_strides_[ndim_a - 3] : 0; + Tensor::Stride batch_stride_b = + ndim_b > 2 ? b_strides_[ndim_b - 3] : 0; + Tensor::Stride batch_stride_out = + ndim_out > 2 ? out_strides_[ndim_out - 3] : 0; + + // Bias stride: for 1D bias [N], stride is 1. For batched bias, use last + // stride. 
+ Tensor::Stride bias_stride = 0; + if (Bias && bias) { + auto ndim_bias = bias->shape().size(); + bias_stride = bias->strides()[ndim_bias - 1]; + } + + for (Tensor::Size batch = 0; batch < batch_count; ++batch) { + const auto* A_batch = A + batch * batch_stride_a; + const auto* B_batch = B + batch * batch_stride_b; + auto* Out_batch = Out + batch * batch_stride_out; + + for (Tensor::Size i = 0; i < M; ++i) { + for (Tensor::Size j = 0; j < N; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < K; ++l) { + float a_val = + Cast(A_batch[i * stride_a_m + l * stride_a_k]); + float b_val = + Cast(B_batch[l * stride_b_k + j * stride_b_n]); + sum += a_val * b_val; + } + + if (Bias) { + sum += Cast(Bias[j * bias_stride]); + } + + Out_batch[i * stride_out_m + j * stride_out_n] = Cast(sum); + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/mul/mul.h b/src/cpu/mul/mul.h new file mode 100644 index 0000000..0bdefb9 --- /dev/null +++ b/src/cpu/mul/mul.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_CPU_MUL_MUL_H_ +#define INFINI_OPS_CPU_MUL_MUL_H_ + +#include + +#include "base/mul.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Mul, + Caster { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul{input, other, out} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, other, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor input, const Tensor other, Tensor out) const { + using ComputeType = std::conditional_t || + IsFP16, + float, T>; + + const auto* input_ptr = static_cast(input.data()); + const auto* other_ptr = static_cast(other.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* 
shape, + const auto* strides) { + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(), + other_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = Cast(Cast(input_ptr[input_idx]) * + Cast(other_ptr[other_idx])); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/hash.h b/src/hash.h index efb34f7..4721f33 100644 --- a/src/hash.h +++ b/src/hash.h @@ -2,6 +2,7 @@ #define INFINI_OPS_HASH_H_ #include +#include template inline void HashCombine(std::size_t& seed, const T& v) { @@ -9,4 +10,12 @@ inline void HashCombine(std::size_t& seed, const T& v) { seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +template +inline void HashCombine(std::size_t& seed, const std::vector& v) { + HashCombine(seed, v.size()); + for (const auto& elem : v) { + HashCombine(seed, elem); + } +} + #endif diff --git a/src/operator.h b/src/operator.h index 72e8337..99e584c 100644 --- a/src/operator.h +++ b/src/operator.h @@ -37,6 +37,14 @@ struct CacheKey { tensors.push_back(t); } + void Absorb(const std::vector& ts) { + HashCombine(hash, ts.size()); + for (const auto& t : ts) { + HashCombine(hash, t); + tensors.push_back(t); + } + } + template void Absorb(const T& v) { HashCombine(hash, v); diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 766b6ea..acbb52b 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -118,10 +118,20 @@ inline Tensor TensorFromPybind11Handle(py::handle obj) { inline std::optional OptionalTensorFromPybind11Handle( const std::optional& obj) { - if (!obj.has_value()) return std::nullopt; + if (!obj.has_value() || obj->is_none()) return std::nullopt; return 
TensorFromPybind11Handle(*obj); } +inline std::vector VectorTensorFromPybind11Handle( + const std::vector& objs) { + std::vector result; + result.reserve(objs.size()); + for (const auto& obj : objs) { + result.push_back(TensorFromPybind11Handle(obj)); + } + return result; +} + } // namespace infini::ops #endif diff --git a/tests/test_add_rms_norm.py b/tests/test_add_rms_norm.py new file mode 100644 index 0000000..b2b7b87 --- /dev/null +++ b/tests/test_add_rms_norm.py @@ -0,0 +1,95 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, strides", + ( + ((1, 64), None), + ((2, 128), None), + ((4, 48, 64), None), + ((2, 4, 2048), None), + ((1, 64), (64, 1)), + ((4, 48, 64), (3072, 64, 1)), + ), +) +@pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize("implementation_index", (0, 1)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-4, 1e-4), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 2e-2, 1e-2), + ), +) +def test_add_rms_norm( + shape, + strides, + eps, + implementation_index, + dtype, + device, + rtol, + atol, +): + active_indices = infini.ops.AddRmsNorm.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + + weight_shape = (shape[-1],) + x1 = randn_strided(shape, strides, dtype=dtype, device=device) + x2 = randn_strided(shape, strides, dtype=dtype, device=device) + gamma = randn_strided(weight_shape, None, dtype=dtype, device=device) + y_out = empty_strided(shape, strides, dtype=dtype, device=device) + x_out = empty_strided(shape, strides, dtype=dtype, device=device) + + return Payload( + lambda *args, **kwargs: _add_rms_norm( + *args, **kwargs, implementation_index=implementation_index + ), + _torch_add_rms_norm, + (x1, x2, gamma), + {"eps": 
eps, "y_out": y_out, "x_out": x_out}, + rtol=rtol, + atol=atol, + ) + + +def _add_rms_norm(x1, x2, gamma, *, eps=1e-6, y_out=None, x_out=None, + implementation_index=0): + if x1.device.type == "npu": + infini.ops.add_rms_norm( + x1, x2, gamma, eps, y_out, x_out, + implementation_index=implementation_index, + stream=get_npu_stream(x1), + ) + else: + infini.ops.add_rms_norm( + x1, x2, gamma, eps, y_out, x_out, + implementation_index=implementation_index, + ) + + # Concatenate both outputs into a single flat tensor for allclose comparison. + return torch.cat([y_out.contiguous().flatten(), x_out.contiguous().flatten()]) + + +def _torch_add_rms_norm(x1, x2, gamma, *, eps=1e-6, y_out=None, x_out=None): + x_sum = x1 + x2 + + if x_out is not None: + x_out.copy_(x_sum) + + rms = torch.sqrt(torch.mean(x_sum.float() * x_sum.float(), dim=-1, + keepdim=True) + eps) + y = (x_sum.float() / rms * gamma.float()).to(x1.dtype) + + if y_out is not None: + y_out.copy_(y) + + return torch.cat([y_out.contiguous().flatten(), x_out.contiguous().flatten()]) diff --git a/tests/test_cast.py b/tests/test_cast.py new file mode 100644 index 0000000..24b50ee --- /dev/null +++ b/tests/test_cast.py @@ -0,0 +1,65 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides", + ( + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4, 4), None, None), + ((16, 5632), None, None), + ((4, 4, 5632), None, None), + ), +) +@pytest.mark.parametrize( + ("input_dtype", "out_dtype", "rtol", "atol"), + ( + (torch.float16, torch.float32, 1e-3, 1e-3), + (torch.float32, torch.float16, 1e-3, 1e-3), + (torch.bfloat16, torch.float32, 1e-2, 5e-3), + (torch.float32, torch.bfloat16, 1e-2, 5e-3), + (torch.float16, torch.bfloat16, 1e-2, 5e-3), + (torch.bfloat16, torch.float16, 1e-2, 5e-3), + ), +) +def test_cast( + shape, + input_strides, + 
out_strides, + input_dtype, + out_dtype, + device, + rtol, + atol, +): + input = randn_strided(shape, input_strides, dtype=input_dtype, device=device) + out = empty_strided(shape, out_strides, dtype=out_dtype, device=device) + + return Payload( + _cast, + _torch_cast, + (input, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cast(input, out): + if input.device.type == "npu": + infini.ops.cast(input, out, stream=get_npu_stream(input)) + else: + infini.ops.cast(input, out) + + return out + + +def _torch_cast(input, out): + out.copy_(input.to(out.dtype)) + + return out diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 0000000..dfdb059 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,72 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shapes, dim, out_shape", + ( + # 2 inputs, dim=0 + (((4, 64), (4, 64)), 0, (8, 64)), + # 2 inputs, dim=1 + (((4, 32), (4, 64)), 1, (4, 96)), + # 3 inputs, dim=1 + (((4, 16), (4, 32), (4, 16)), 1, (4, 64)), + # 2 inputs, dim=0, 3D + (((2, 4, 64), (2, 4, 64)), 0, (4, 4, 64)), + # 2 inputs, dim=2, 3D + (((2, 4, 32), (2, 4, 64)), 2, (2, 4, 96)), + # 4 inputs, dim=1 + (((1, 1024), (1, 1024), (1, 1024), (1, 1024)), 1, (1, 4096)), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_cat(shapes, dim, out_shape, dtype, device, rtol, atol): + inputs = [ + randn_strided(s, None, dtype=dtype, device=device) for s in shapes + ] + out = empty_strided(out_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _cat(*args, dim=dim), + lambda *args: _torch_cat(*args, dim=dim), + (*inputs, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _cat(*args, dim): + inputs = list(args[:-1]) + out = args[-1] + + first = inputs[0] + rest = 
inputs[1:] + + if first.device.type == "npu": + infini.ops.cat(first, rest, dim, out, stream=get_npu_stream(first)) + else: + infini.ops.cat(first, rest, dim, out) + + return out + + +def _torch_cat(*args, dim): + inputs = list(args[:-1]) + out = args[-1] + + result = torch.cat(inputs, dim=dim) + out.copy_(result) + + return out diff --git a/tests/test_linear.py b/tests/test_linear.py new file mode 100644 index 0000000..33cd963 --- /dev/null +++ b/tests/test_linear.py @@ -0,0 +1,95 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, out_shape", + ( + ((4, 64), (64, 32), (4, 32)), + ((2, 128), (128, 256), (2, 256)), + ((1, 4096), (4096, 4096), (1, 4096)), + ((2, 4, 64), (2, 64, 32), (2, 4, 32)), + ((4, 8, 128), (4, 128, 64), (4, 8, 64)), + ), +) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize("has_bias", (False, True)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-2, 5e-2), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_linear( + a_shape, + b_shape, + out_shape, + trans_a, + trans_b, + has_bias, + dtype, + device, + rtol, + atol, +): + a = randn_strided(a_shape, None, dtype=dtype, device=device) + b = randn_strided(b_shape, None, dtype=dtype, device=device) + + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + # Bias shape is [N], the last dim of the output. 
+ bias = None + + if has_bias: + N = out_shape[-1] + bias = randn_strided((N,), None, dtype=dtype, device=device) + + out = empty_strided(out_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _linear(*args, trans_a=trans_a, trans_b=trans_b), + lambda *args: _torch_linear(*args, trans_a=trans_a, trans_b=trans_b), + (a, b, bias, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _linear(a, b, bias, out, trans_a=False, trans_b=False): + if a.device.type == "npu": + infini.ops.linear( + a, b, bias, trans_a, trans_b, out, stream=get_npu_stream(a) + ) + else: + infini.ops.linear(a, b, bias, trans_a, trans_b, out) + + return out + + +def _torch_linear(a, b, bias, out, trans_a=False, trans_b=False): + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + result = torch.matmul(a.float(), b.float()) + + if bias is not None: + result = result + bias.float() + + out.copy_(result.to(out.dtype)) + + return out diff --git a/tests/test_matmul.py b/tests/test_matmul.py new file mode 100644 index 0000000..dae3961 --- /dev/null +++ b/tests/test_matmul.py @@ -0,0 +1,79 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, c_shape", + ( + ((4, 64), (64, 32), (4, 32)), + ((2, 128), (128, 256), (2, 256)), + ((2, 4, 64), (2, 64, 32), (2, 4, 32)), + ((4, 8, 128), (4, 128, 64), (4, 8, 64)), + ), +) +@pytest.mark.parametrize("trans_a", (False, True)) +@pytest.mark.parametrize("trans_b", (False, True)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-2, 1e-2), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 1e-2, 1e-2), + ), +) +def test_matmul( + a_shape, + b_shape, + c_shape, + trans_a, + trans_b, + dtype, + device, + rtol, + atol, +): + a = randn_strided(a_shape, None, dtype=dtype, device=device) + b = randn_strided(b_shape, None, 
dtype=dtype, device=device) + + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + c = empty_strided(c_shape, None, dtype=dtype, device=device) + + return Payload( + lambda *args: _matmul(*args, trans_a=trans_a, trans_b=trans_b), + lambda *args: _torch_matmul(*args, trans_a=trans_a, trans_b=trans_b), + (a, b, c), + {}, + rtol=rtol, + atol=atol, + ) + + +def _matmul(a, b, c, trans_a=False, trans_b=False): + if a.device.type == "npu": + infini.ops.matmul(a, b, c, trans_a, trans_b, stream=get_npu_stream(a)) + else: + infini.ops.matmul(a, b, c, trans_a, trans_b) + + return c + + +def _torch_matmul(a, b, c, trans_a=False, trans_b=False): + if trans_a: + a = a.transpose(-2, -1) + + if trans_b: + b = b.transpose(-2, -1) + + result = torch.matmul(a.float(), b.float()).to(c.dtype) + c.copy_(result) + + return c diff --git a/tests/test_mul.py b/tests/test_mul.py new file mode 100644 index 0000000..ea7f918 --- /dev/null +++ b/tests/test_mul.py @@ -0,0 +1,90 @@ +import infini.ops +import pytest +import torch + +from tests.utils import ( + Payload, + empty_strided, + get_npu_stream, + randint_strided, + randn_strided, +) + +_INT_DTYPES = (torch.int16, torch.int32, torch.int64) + +_UINT_DTYPES = tuple( + filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64))) +) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, other_strides, out_strides", + ( + ((13, 4), None, None, None), + ((13, 4), (10, 1), (10, 1), (10, 1)), + ((13, 4), (0, 1), None, None), + ((13, 4, 4), None, None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), + ((16, 5632), None, None, None), + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), + ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)), + ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)), + ((4, 4, 5632), None, None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), + ), +) 
+@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ) + + tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES), +) +def test_mul( + shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol +): + if device == "musa" and dtype in _UINT_DTYPES: + pytest.skip( + "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." + ) + + if dtype in _INT_DTYPES or dtype in _UINT_DTYPES: + input = randint_strided( + 0, 100, shape, input_strides, dtype=dtype, device=device + ) + other = randint_strided( + 0, 100, shape, other_strides, dtype=dtype, device=device + ) + else: + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload(_mul, _torch_mul, (input, other, out), {}, rtol=rtol, atol=atol) + + +def _mul(input, other, out): + if input.device.type == "npu": + infini.ops.mul(input, other, out, stream=get_npu_stream(input)) + else: + infini.ops.mul(input, other, out) + + return out + + +def _torch_mul(input, other, out): + if input.dtype in _UINT_DTYPES: + input = input.to(torch.int64) + + if other.dtype in _UINT_DTYPES: + other = other.to(torch.int64) + + res = torch.mul(input, other) + out.copy_(res.to(out.dtype)) + + return out diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py index 733ae43..d2a7c93 100644 --- a/tests/test_rotary_embedding.py +++ b/tests/test_rotary_embedding.py @@ -115,6 +115,16 @@ def test_rotary_embedding_full( if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): pytest.skip("NPU not available") + if device == "npu" and not is_neox_style: + pytest.skip( + "Ascend aclnnApplyRotaryPosEmbV2 only supports neox style " + "(rotaryMode='half')" + ) + + # 
aclnnApplyRotaryPosEmbV2 accumulates with ~4 ULP error for float16. + if device == "npu" and dtype == torch.float16: + atol = 0.01 + num_kv_heads = num_heads rotary_dim = head_size num_tokens = 16 @@ -207,6 +217,11 @@ def test_rotary_embedding_partial( if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): pytest.skip("NPU not available") + if device == "npu": + pytest.skip( + "Ascend aclnnApplyRotaryPosEmbV2 requires rotary_dim == head_size" + ) + num_tokens = 16 max_seq_len = 64 From 4f90b5a807a9489a7f9fa484798a63af369864dd Mon Sep 17 00:00:00 2001 From: zhangyue Date: Thu, 9 Apr 2026 17:04:34 +0800 Subject: [PATCH 7/8] fix(ascend): stabilize `WorkspacePool` pointers and remove dead code Use `unique_ptr` in the arena map so that thread-local cached pointers remain valid across `unordered_map` rehashes. Remove unused `detail::reshapeView` helper from FlashAttention. --- src/ascend/flash_attention/kernel.h | 23 --------------------- src/ascend/workspace_pool_.h | 32 ++++++++++++++++++----------- 2 files changed, 20 insertions(+), 35 deletions(-) diff --git a/src/ascend/flash_attention/kernel.h b/src/ascend/flash_attention/kernel.h index 3dae947..d8545d9 100644 --- a/src/ascend/flash_attention/kernel.h +++ b/src/ascend/flash_attention/kernel.h @@ -17,29 +17,6 @@ namespace infini::ops { namespace detail { -// Build an aclTensor with a different view shape/stride but the same data -// pointer. 
-inline aclTensor* reshapeView(const Tensor& t, - const std::vector& new_shape, - const std::vector& new_strides) { - int64_t storage_elems = 1; - for (size_t i = 0; i < new_shape.size(); ++i) { - if (new_shape[i] == 0) { - storage_elems = 0; - break; - } - if (new_strides[i] > 0 && new_shape[i] > 1) { - storage_elems += static_cast(new_shape[i] - 1) * new_strides[i]; - } - } - std::vector storage_shape = {storage_elems}; - return aclCreateTensor( - new_shape.data(), static_cast(new_shape.size()), - ascend::toAclDtype(t.dtype()), new_strides.data(), 0, ACL_FORMAT_ND, - storage_shape.data(), static_cast(storage_shape.size()), - const_cast(t.data())); -} - // Extract cu_seqlens differences to a host aclIntArray. // cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...]. // Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths). diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h index 3d0f698..ecbfb69 100644 --- a/src/ascend/workspace_pool_.h +++ b/src/ascend/workspace_pool_.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -37,45 +38,52 @@ class WorkspacePool { } // Slow path: look up arena in the map under lock. + // Arenas are heap-allocated via `unique_ptr` so that pointers remain stable + // across `unordered_map` rehashes (which invalidate value references). 
std::lock_guard lock(mutex_); - auto& arena = arenas_[stream]; - if (needed > arena.capacity) { - if (arena.capacity > 0) { + auto& slot = arenas_[stream]; + if (!slot) { + slot = std::make_unique(); + } + auto* arena = slot.get(); + if (needed > arena->capacity) { + if (arena->capacity > 0) { aclrtSynchronizeStream(stream); - aclrtFree(arena.buf); + aclrtFree(arena->buf); } if (needed > 0) { - auto ret = aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + auto ret = + aclrtMalloc(&arena->buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); } - arena.capacity = needed; + arena->capacity = needed; } last_stream = stream; - last_arena = &arena; - return arena; + last_arena = arena; + return *arena; } ~WorkspacePool() { for (auto& [stream, arena] : arenas_) { - if (arena.capacity > 0) { + if (arena && arena->capacity > 0) { // The CANN runtime may already be torn down when this static // destructor runs. aclrtGetDevice fails in that case — skip the // free to avoid glibc "double free" abort. int32_t dev_id = -1; if (aclrtGetDevice(&dev_id) == ACL_SUCCESS) { - aclrtFree(arena.buf); + aclrtFree(arena->buf); } else { fprintf(stderr, "[InfiniOps] `WorkspacePool`: CANN runtime already finalized, " "skipping `aclrtFree` (%" PRIu64 " bytes leaked).\n", - arena.capacity); + arena->capacity); } } } } private: - std::unordered_map arenas_; + std::unordered_map> arenas_; std::mutex mutex_; }; From 3f43d577ace9e709f6fd119a35497f5fdbb94d12 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Thu, 9 Apr 2026 17:20:30 +0800 Subject: [PATCH 8/8] fix(cat): support negative dim and document TensorList caching assumption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Normalize negative `dim` in the base class constructor (e.g. -1 → last dimension). 
Add comment in the Ascend kernel explaining why `aclSetRawTensorAddr` on TensorList-contained descriptors is sufficient without `aclSetInputTensorAddr`. Add negative-dim test case. --- src/ascend/cat/kernel.h | 7 +++++-- src/base/cat.h | 9 +++++---- src/cpu/cat/cat.h | 4 +++- tests/test_cat.py | 2 ++ 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/ascend/cat/kernel.h b/src/ascend/cat/kernel.h index a847b92..aae90e0 100644 --- a/src/ascend/cat/kernel.h +++ b/src/ascend/cat/kernel.h @@ -34,7 +34,7 @@ class Operator : public Cat { } void operator()(const Tensor first_input, std::vector rest_inputs, - int64_t dim, Tensor out) const override { + int64_t /*dim*/, Tensor out) const override { auto stream = static_cast(stream_); // Collect all input tensors in order. @@ -63,7 +63,10 @@ class Operator : public Cat { &executor_); aclSetAclOpExecutorRepeatable(executor_); } else { - // Subsequent calls: update data pointers on cached descriptors. + // Subsequent calls: update data pointers on cached descriptors via + // `aclSetRawTensorAddr`. The executor holds references to the same + // `aclTensor*` objects inside `tensor_list_`, so updating their data + // pointers is sufficient — no `aclSetInputTensorAddr` needed. for (size_t i = 0; i < input_count_; ++i) { in_caches_[i].get(const_cast(inputs[i]->data())); } diff --git a/src/base/cat.h b/src/base/cat.h index 16f9bd2..6d16d12 100644 --- a/src/base/cat.h +++ b/src/base/cat.h @@ -11,12 +11,13 @@ class Cat : public Operator { public: Cat(const Tensor first_input, std::vector rest_inputs, int64_t dim, Tensor out) - : dim_{dim}, input_count_{1 + rest_inputs.size()} { + : input_count_{1 + rest_inputs.size()} { assert(input_count_ >= 2 && "Cat requires at least 2 input tensors"); - auto ndim = out.ndim(); - assert(dim >= 0 && dim < static_cast(ndim) && - "Cat dim out of range"); + auto ndim = static_cast(out.ndim()); + // Normalize negative dim (e.g. -1 means last dimension). + dim_ = dim < 0 ? 
dim + ndim : dim; + assert(dim_ >= 0 && dim_ < ndim && "Cat dim out of range"); } virtual void operator()(const Tensor first_input, diff --git a/src/cpu/cat/cat.h b/src/cpu/cat/cat.h index d49b023..ed3f41d 100644 --- a/src/cpu/cat/cat.h +++ b/src/cpu/cat/cat.h @@ -17,7 +17,7 @@ class Operator : public Cat { : Cat{first_input, rest_inputs, dim, out} {} void operator()(const Tensor first_input, std::vector rest_inputs, - int64_t dim, Tensor out) const override { + int64_t /*dim*/, Tensor out) const override { // Collect all input tensors. std::vector inputs; inputs.reserve(input_count_); @@ -26,6 +26,8 @@ class Operator : public Cat { inputs.push_back(&t); } + // Use normalized `dim_` from base class (handles negative dim). + auto dim = dim_; auto elem_size = kDataTypeToSize.at(out.dtype()); auto ndim = out.ndim(); auto out_shape = out.shape(); diff --git a/tests/test_cat.py b/tests/test_cat.py index dfdb059..9346802 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -13,6 +13,8 @@ (((4, 64), (4, 64)), 0, (8, 64)), # 2 inputs, dim=1 (((4, 32), (4, 64)), 1, (4, 96)), + # 2 inputs, dim=-1 (negative dim) + (((4, 32), (4, 64)), -1, (4, 96)), # 3 inputs, dim=1 (((4, 16), (4, 32), (4, 16)), 1, (4, 64)), # 2 inputs, dim=0, 3D