diff --git a/.ci/README.md b/.ci/README.md index 190d012..f44b5a3 100644 --- a/.ci/README.md +++ b/.ci/README.md @@ -158,7 +158,7 @@ Platform is auto-detected (via `nvidia-smi`/`ixsmi`/`mx-smi`/`mthreads-gmi`/`cnm | `--stage` | Run only the specified stage | | `--image-tag` | Override image tag | | `--gpu-id` | Override GPU device IDs (nvidia via `--gpus`, others via `CUDA_VISIBLE_DEVICES`) | -| `--test` | Override pytest test path (e.g., `tests/test_gemm.py::test_gemm`) | +| `--test` | Replace stage command entirely (e.g., `pytest tests/test_add.py -v`) | | `--results-dir` | Host directory mounted to `/workspace/results` inside the container | | `--local` | Mount current directory (read-only) instead of cloning from git | | `--dry-run` | Print docker command without executing | @@ -195,7 +195,7 @@ Proxy vars are forwarded from the host. Test results are written to `--results-d | MetaX | `--privileged` | `none` | `maca-pytorch:3.2.1.4-...` | `mx-smi` | | Moore | `--privileged` | `none` | `vllm_musa:20251112_hygon` | `mthreads-gmi` | | Cambricon | `--privileged` | `mlu` | `cambricon/pytorch:v1.25.3` | `cnmon` | -| Ascend | TODO | — | `ascend-pytorch:24.0.0` | — | +| Ascend | `--privileged` + device mounts | `npu` | `ascend-pytorch:24.0.RC3-A2-2.1.0` | `npu-smi` | `gpu_style` controls the Docker device injection mechanism: `nvidia` uses `--gpus`, `none` uses `CUDA_VISIBLE_DEVICES` (or skips injection for Moore), `mlu` uses `MLU_VISIBLE_DEVICES`. 
diff --git a/.ci/ci_resource.py b/.ci/ci_resource.py index 51b181f..de2953d 100644 --- a/.ci/ci_resource.py +++ b/.ci/ci_resource.py @@ -14,6 +14,7 @@ GPU_STYLE_NVIDIA = "nvidia" GPU_STYLE_NONE = "none" GPU_STYLE_MLU = "mlu" +GPU_STYLE_NPU = "npu" @dataclass @@ -44,6 +45,7 @@ class ResourcePool: "metax": "mx-smi", "moore": "mthreads-gmi", "cambricon": "cnmon", + "ascend": "npu-smi", } def __init__(self, platform, utilization_threshold=10): @@ -72,6 +74,9 @@ def detect_gpus(self) -> list[GpuInfo]: if self._platform == "cambricon": return self._detect_gpus_cambricon() + if self._platform == "ascend": + return self._detect_gpus_ascend() + tool = self.GPU_QUERY_TOOLS.get(self._platform) if not tool: @@ -325,6 +330,73 @@ def _detect_gpus_cambricon(self) -> list[GpuInfo]: return sorted(gpus, key=operator.attrgetter("index")) + def _detect_gpus_ascend(self) -> list[GpuInfo]: + """Parse npu-smi info output for Huawei Ascend NPUs. + + Output format (pipe-delimited table, two rows per NPU): + | 0 910B4 | OK | 86.5 41 ... + | 0 | 0000:C1:00.0 | 0 0 / 0 2789 / 32768 | + Row 1: index, name, health, power, temp, hugepages. + Row 2: chip_id, bus_id, aicore_util, memory_usage, hbm_usage. + """ + try: + result = subprocess.run( + ["npu-smi", "info"], + capture_output=True, + text=True, + timeout=10, + ) + except (FileNotFoundError, subprocess.TimeoutExpired): + return [] + + if result.returncode != 0: + return [] + + gpus = [] + lines = result.stdout.splitlines() + i = 0 + + while i < len(lines): + line = lines[i] + # Match row 1: "| {index} {name} ..." + m1 = re.match(r"^\|\s+(\d+)\s+", line) + + if m1 and i + 1 < len(lines): + try: + npu_index = int(m1.group(1)) + aicore_m = re.match( + r"^\|\s+\d+\s+\|\s+[\da-f:.]+\s+\|\s*([\d.]+)\s", lines[i + 1] + ) + + util_pct = float(aicore_m.group(1)) if aicore_m else 0.0 + + # Parse HBM usage from row 2: "{used} / {total}". 
+ hbm_m = re.search(r"([\d.]+)\s*/\s*([\d.]+)", lines[i + 1]) + + if hbm_m: + used_mb = float(hbm_m.group(1)) + total_mb = float(hbm_m.group(2)) + else: + used_mb, total_mb = 0.0, 0.0 + + gpus.append( + GpuInfo( + index=npu_index, + memory_used_mb=used_mb, + memory_total_mb=total_mb, + utilization_pct=util_pct, + ) + ) + except (ValueError, AttributeError): + pass + + i += 2 + continue + + i += 1 + + return sorted(gpus, key=operator.attrgetter("index")) + def detect_system_resources(self) -> SystemResources: """Read system memory from /proc/meminfo and CPU count.""" total_mb = 0.0 diff --git a/.ci/config.yaml b/.ci/config.yaml index b70e7df..a6a5e70 100644 --- a/.ci/config.yaml +++ b/.ci/config.yaml @@ -137,10 +137,34 @@ platforms: - name: test run: pytest tests/test_gemm.py -n 4 -v --tb=short --junitxml=/workspace/results/test-results.xml - ascend: # TODO: Ascend image is not ready yet + ascend: image: dockerfile: .ci/images/ascend/ build_args: - BASE_IMAGE: ascendhub.huawei.com/public-ascendhub/ascend-pytorch:24.0.0 - private_sdk: - source_env: PRIVATE_SDK_URL + BASE_IMAGE: quay.io/ascend/vllm-ascend:v0.18.0rc1-openeuler + PIP_INDEX_URL: https://pypi.org/simple + docker_args: + - "--runtime=runc" + - "--privileged" + - "--device=/dev/davinci0" + - "--device=/dev/davinci_manager" + - "--device=/dev/devmm_svm" + - "--device=/dev/hisi_hdc" + volumes: + - /usr/local/Ascend/driver:/usr/local/Ascend/driver:ro + - /usr/local/dcmi:/usr/local/dcmi:ro + - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro + env: + ASCEND_HOME_PATH: /usr/local/Ascend/ascend-toolkit/latest + setup: pip install .[dev] --no-build-isolation + jobs: + npu: + resources: + gpu_ids: "0" + gpu_style: npu + memory: 32GB + shm_size: 16g + timeout: 3600 + stages: + - name: test + run: pytest tests/ -n 1 -k npu -v --tb=short --junitxml=/workspace/results/test-results.xml diff --git a/.ci/images/ascend/Dockerfile b/.ci/images/ascend/Dockerfile index 66392eb..5391d7d 100644 --- 
a/.ci/images/ascend/Dockerfile +++ b/.ci/images/ascend/Dockerfile @@ -1,7 +1,7 @@ ARG BASE_IMAGE FROM ${BASE_IMAGE} -ENV DEBIAN_FRONTEND=noninteractive +USER root ARG HTTP_PROXY ARG HTTPS_PROXY @@ -10,30 +10,22 @@ ARG http_proxy ARG https_proxy ARG no_proxy -RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - git \ - cmake \ - ninja-build \ - coreutils \ - curl \ - libclang-dev \ - && rm -rf /var/lib/apt/lists/* - -ARG PRIVATE_SDK_URL -RUN if [ -n "$PRIVATE_SDK_URL" ]; then \ - curl -fSL "$PRIVATE_SDK_URL" -o /tmp/sdk.run && \ - chmod +x /tmp/sdk.run && /tmp/sdk.run --quiet && \ - rm /tmp/sdk.run; \ - fi - -RUN pip install --no-cache-dir \ +ARG PIP_INDEX_URL +RUN pip install --no-cache-dir --progress-bar off \ + ${PIP_INDEX_URL:+--index-url "$PIP_INDEX_URL"} \ + libclang \ + ninja \ scikit-build-core \ pybind11 \ - libclang \ pytest \ pytest-cov \ pytest-xdist \ - pyyaml + ruff==0.15.7 + +# Pin pre-installed torch to prevent pip from replacing it. +RUN pip show torch >/dev/null 2>&1 && \ + echo "torch==$(pip show torch | grep '^Version:' | awk '{print $2}')" > /etc/pip-constraints.txt || \ + touch /etc/pip-constraints.txt +ENV PIP_CONSTRAINT=/etc/pip-constraints.txt WORKDIR /workspace diff --git a/.ci/run.py b/.ci/run.py index 24a8867..092d338 100644 --- a/.ci/run.py +++ b/.ci/run.py @@ -13,47 +13,19 @@ GPU_STYLE_NVIDIA, GPU_STYLE_NONE, GPU_STYLE_MLU, + GPU_STYLE_NPU, ResourcePool, detect_platform, ) from utils import get_git_commit, load_config -# Flags that consume the next token as their value (e.g. -n 4, -k expr). -_PYTEST_VALUE_FLAGS = {"-n", "-k", "-m", "-p", "--tb", "--junitxml", "--rootdir"} +def apply_test_override(run_cmd, test_cmd): + """Replace a stage command with *test_cmd*. - -def apply_test_override(run_cmd, test_path): - """Replace positional test path(s) in a pytest stage command. - - For example: ``pytest tests/ -n 4 ...`` becomes - ``pytest tests/test_gemm.py -n 4 ...`` when ``test_path`` is - ``tests/test_gemm.py``. 
+ ``--test`` always replaces the entire stage command regardless of whether + the original is pytest or something else. """ - parts = shlex.split(run_cmd) - - if not parts or parts[0] != "pytest": - return run_cmd - - result = ["pytest", test_path] - skip_next = False - - for p in parts[1:]: - if skip_next: - result.append(p) - skip_next = False - continue - - if p.startswith("-"): - result.append(p) - if p in _PYTEST_VALUE_FLAGS: - skip_next = True - continue - - # Skip existing test paths; the override is already in result[1]. - if not ("/" in p or p.endswith(".py") or "::" in p): - result.append(p) - - return shlex.join(result) + return test_cmd def build_results_dir(base, platform, stages, commit): @@ -212,6 +184,9 @@ def build_docker_args( # For Cambricon MLU platforms that use --privileged, # control visible devices via MLU_VISIBLE_DEVICES. args.extend(["-e", f"MLU_VISIBLE_DEVICES={gpu_id}"]) + elif gpu_style == GPU_STYLE_NPU and gpu_id and gpu_id != "all": + # Ascend: control visible NPU via ASCEND_VISIBLE_DEVICES. + args.extend(["-e", f"ASCEND_VISIBLE_DEVICES={gpu_id}"]) memory = resources.get("memory") @@ -315,7 +290,7 @@ def main(): parser.add_argument( "--test", type=str, - help='Override pytest test path, e.g. "tests/test_gemm.py" or "tests/test_gemm.py::test_gemm"', + help='Replace stage command with this (e.g. 
"pytest tests/test_add.py -v")', ) parser.add_argument( "--local", diff --git a/.ci/tests/test_resource.py b/.ci/tests/test_resource.py index cbe37d8..0db3fbb 100644 --- a/.ci/tests/test_resource.py +++ b/.ci/tests/test_resource.py @@ -93,6 +93,7 @@ def test_detect_system_resources(monkeypatch, tmp_path): "MemAvailable: 20000000 kB\n" ) + _real_open = open def fake_open(path, **kw): diff --git a/.ci/tests/test_run.py b/.ci/tests/test_run.py index 93987e5..65c6de6 100644 --- a/.ci/tests/test_run.py +++ b/.ci/tests/test_run.py @@ -296,3 +296,36 @@ def test_build_results_dir_under_base(): stages = [{"name": "test", "run": "pytest"}] d = run.build_results_dir("/tmp/my-results", "ascend", stages, "def5678") assert d.parent == Path("/tmp/my-results") + + +# --------------------------------------------------------------------------- +# Tests for `apply_test_override`. +# --------------------------------------------------------------------------- + + +def test_apply_test_override_replaces_pytest_command(): + assert run.apply_test_override("pytest tests/ -v", "pytest tests/test_add.py") == ( + "pytest tests/test_add.py" + ) + + +def test_apply_test_override_replaces_non_pytest_command(): + assert run.apply_test_override("ruff check .", "python docs/repro.py") == ( + "python docs/repro.py" + ) + + +def test_apply_test_override_replaces_empty_command(): + assert run.apply_test_override("", "bash script.sh") == "bash script.sh" + + +def test_apply_test_override_preserves_user_flags(): + cmd = "pytest tests/test_gemm.py -n 1 -v --tb=short" + assert run.apply_test_override("pytest tests/ -n 4", cmd) == cmd + + +def test_apply_test_override_with_shell_command(): + assert ( + run.apply_test_override("pytest tests/", "cd /tmp && python repro.py") + == "cd /tmp && python repro.py" + ) diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h new file mode 100644 index 0000000..e81f9bd --- /dev/null +++ b/src/ascend/add/kernel.h @@ -0,0 +1,58 @@ +#ifndef 
INFINI_OPS_ASCEND_ADD_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add(input, other, out) { + // aclCreateScalar stores the pointer rather than copying the value, so + // alpha_storage_* must remain alive for the lifetime of alpha_. + // The alpha scalar type must match the tensor dtype: use int64 for integer + // dtypes and float for floating-point dtypes. + if (ascend::isIntegerDtype(input.dtype())) { + alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64); + } else { + alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT); + } + } + + ~Operator() { aclDestroyScalar(alpha_); } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = ascend::buildAclTensor(input); + auto t_oth = ascend::buildAclTensor(other); + auto t_out = ascend::buildAclTensor(out); + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_needed, &executor); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnAdd(arena.buf, ws_needed, executor, stream); + aclDestroyTensor(t_in); + aclDestroyTensor(t_oth); + aclDestroyTensor(t_out); + } + + private: + float alpha_float_storage_ = + 1.0f; // stable address for aclCreateScalar (float) + int64_t alpha_int_storage_ = 1; // stable address for aclCreateScalar (int) + aclScalar* alpha_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/kernel.h b/src/ascend/add_rms_norm/kernel.h new file mode 100644 index 0000000..28ae702 --- /dev/null +++ 
b/src/ascend/add_rms_norm/kernel.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add_rms_norm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add_rms_norm.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps, + Tensor y_out, Tensor x_out) + : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out) { + // aclnnAddRmsNorm writes rstd as a required side output. + // Allocate a persistent device buffer for it. + size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + + ~Operator() { + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor x1, const Tensor x2, const Tensor gamma, + float eps, Tensor y_out, Tensor x_out) const override { + auto t_x1 = ascend::buildAclTensor(x1); + auto t_x2 = ascend::buildAclTensor(x2); + auto t_gamma = ascend::buildAclTensor(gamma); + auto t_y_out = ascend::buildAclTensor(y_out); + auto t_x_out = ascend::buildAclTensor(x_out); + // rstd is always float32 regardless of input dtype. 
+ auto t_rstd = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_data_); + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclnnAddRmsNormGetWorkspaceSize(t_x1, t_x2, t_gamma, eps, t_y_out, t_rstd, + t_x_out, &ws_needed, &executor); + auto stream = static_cast(stream_); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnAddRmsNorm(arena.buf, ws_needed, executor, stream); + aclDestroyTensor(t_x1); + aclDestroyTensor(t_x2); + aclDestroyTensor(t_gamma); + aclDestroyTensor(t_y_out); + aclDestroyTensor(t_rstd); + aclDestroyTensor(t_x_out); + } + + private: + void* rstd_data_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/causal_softmax/kernel.h b/src/ascend/causal_softmax/kernel.h new file mode 100644 index 0000000..5883c42 --- /dev/null +++ b/src/ascend/causal_softmax/kernel.h @@ -0,0 +1,127 @@ +#ifndef INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnn_masked_fill_scalar.h" +#include "aclnn_softmax.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/causal_softmax.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements causal softmax via three ACLNN calls: +// 1. InplaceCopy(temp, input) — stride-aware copy to contiguous temp +// buffer. +// 2. InplaceMaskedFillScalar(temp, mask, -inf) — apply upper-triangle mask. +// 3. Softmax(temp, dim=-1, out) — softmax over the last dimension. +// +// The boolean causal mask is pre-computed and uploaded to device once in the +// constructor. Its shape (seq_len, total_seq_len) broadcasts over the batch. 
+template <> +class Operator : public CausalSoftmax { + public: + Operator(const Tensor input, Tensor out) : CausalSoftmax(input, out) { + // Contiguous temp buffer with the same element count as input. + size_t n_elems = input.numel(); + size_t elem_bytes = kDataTypeToSize.at(dtype_); + aclrtMalloc(&temp_buf_, n_elems * elem_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + // Build a contiguous Tensor descriptor pointing to temp_buf_. + Tensor temp_t{temp_buf_, input.shape(), input.dtype(), input.device()}; + + // Causal mask: mask[i][j] = 1 when position j must be masked for query i. + // Shape (seq_len, total_seq_len) – broadcasts over the batch dimension. + size_t mask_elems = seq_len_ * total_seq_len_; + std::vector mask_host(mask_elems, 0); + + for (size_t i = 0; i < seq_len_; ++i) { + auto vis_end = static_cast(total_seq_len_ - seq_len_ + i); + + for (auto j = vis_end + 1; j < static_cast(total_seq_len_); + ++j) { + mask_host[i * total_seq_len_ + j] = 1; + } + } + + aclrtMalloc(&mask_buf_, mask_elems, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(mask_buf_, mask_elems, mask_host.data(), mask_elems, + ACL_MEMCPY_HOST_TO_DEVICE); + + std::vector mshape = {static_cast(seq_len_), + static_cast(total_seq_len_)}; + std::vector mstrides = {static_cast(total_seq_len_), 1}; + mask_tensor_ = aclCreateTensor(mshape.data(), mshape.size(), ACL_BOOL, + mstrides.data(), 0, ACL_FORMAT_ND, + mshape.data(), mshape.size(), mask_buf_); + + // Scalar -inf for the masked-fill step. aclCreateScalar stores the pointer + // rather than copying, so neg_inf_storage_ must stay alive with the object. + neg_inf_ = aclCreateScalar(&neg_inf_storage_, ACL_FLOAT); + // Workspaces are allocated lazily on first operator() call. 
+ } + + ~Operator() { + aclrtFree(temp_buf_); + aclrtFree(mask_buf_); + aclDestroyTensor(mask_tensor_); + aclDestroyScalar(neg_inf_); + } + + void operator()(const Tensor input, Tensor out) const override { + Tensor temp_t{temp_buf_, input.shape(), input.dtype(), input.device()}; + auto t_in = ascend::buildAclTensor(input); + auto t_temp = ascend::buildAclTensor(temp_t); + auto t_out = ascend::buildAclTensor(out); + auto stream = static_cast(stream_); + + uint64_t ws_needed = 0; + aclOpExecutor* exec = nullptr; + + // Step 1: copy input (possibly non-contiguous) into contiguous temp. + aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, &ws_needed, &exec); + auto& copy_arena = ascend::workspacePool().ensure(stream, ws_needed); + uint64_t copy_ws = ws_needed; + aclnnInplaceCopy(copy_arena.buf, copy_ws, exec, stream); + + // Step 2: mask upper-triangle positions with -inf in-place. + ws_needed = 0; + exec = nullptr; + aclnnInplaceMaskedFillScalarGetWorkspaceSize(t_temp, mask_tensor_, neg_inf_, + &ws_needed, &exec); + auto& fill_arena = ascend::workspacePool().ensure(stream, ws_needed); + uint64_t fill_ws = ws_needed; + aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws, exec, stream); + + // Step 3: softmax over the last dimension → out. 
+ ws_needed = 0; + exec = nullptr; + constexpr int64_t kLastDim = -1; + aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &ws_needed, &exec); + auto& softmax_arena = ascend::workspacePool().ensure(stream, ws_needed); + uint64_t softmax_ws = ws_needed; + aclnnSoftmax(softmax_arena.buf, softmax_ws, exec, stream); + + aclDestroyTensor(t_in); + aclDestroyTensor(t_temp); + aclDestroyTensor(t_out); + } + + private: + float neg_inf_storage_ = -std::numeric_limits::infinity(); + void* temp_buf_ = nullptr; + void* mask_buf_ = nullptr; + aclTensor* mask_tensor_ = nullptr; + aclScalar* neg_inf_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/flash_attention/kernel.h b/src/ascend/flash_attention/kernel.h new file mode 100644 index 0000000..3b82e53 --- /dev/null +++ b/src/ascend/flash_attention/kernel.h @@ -0,0 +1,321 @@ +#ifndef INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ +#define INFINI_OPS_ASCEND_FLASH_ATTENTION_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_fused_infer_attention_score_v4.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/flash_attention.h" +#include "operator.h" + +namespace infini::ops { + +namespace detail { + +// Build an aclTensor with a different view shape/stride but the same data +// pointer. 
+inline aclTensor* reshapeView(const Tensor& t, + const std::vector& new_shape, + const std::vector& new_strides) { + int64_t storage_elems = 1; + for (size_t i = 0; i < new_shape.size(); ++i) { + if (new_shape[i] == 0) { + storage_elems = 0; + break; + } + if (new_strides[i] > 0 && new_shape[i] > 1) { + storage_elems += static_cast(new_shape[i] - 1) * new_strides[i]; + } + } + std::vector storage_shape = {storage_elems}; + return aclCreateTensor( + new_shape.data(), static_cast(new_shape.size()), + ascend::toAclDtype(t.dtype()), new_strides.data(), 0, ACL_FORMAT_ND, + storage_shape.data(), static_cast(storage_shape.size()), + const_cast(t.data())); +} + +// Extract cu_seqlens differences to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> per_seq_lens = [s1, s2, ...]. +// Used by paged decode (actualSeqLengthsKv = per-sequence KV lengths). +inline aclIntArray* extractSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + std::vector cu_host(n); + aclrtMemcpyAsync(cu_host.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + std::vector lengths(n - 1); + for (size_t i = 0; i < lengths.size(); ++i) { + lengths[i] = cu_host[i + 1] - cu_host[i]; + } + return aclCreateIntArray(lengths.data(), + static_cast(lengths.size())); +} + +// Extract cumulative end positions from cu_seqlens to a host aclIntArray. +// cu_seqlens = [0, s1, s1+s2, ...] -> cum_lens = [s1, s1+s2, ...]. +// FIA V4 TND varlen uses cumulative end positions, matching the vllm-ascend +// convention for npu_fused_infer_attention_score actual_seq_lengths. 
+inline aclIntArray* cumSeqLengths(const Tensor& cu_seqlens, + aclrtStream stream) { + auto n = cu_seqlens.numel(); + std::vector cu_host(n); + aclrtMemcpyAsync(cu_host.data(), n * sizeof(int64_t), cu_seqlens.data(), + n * sizeof(int64_t), ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + // Skip the leading 0; return [s1, s1+s2, ...]. + return aclCreateIntArray(cu_host.data() + 1, static_cast(n - 1)); +} + +// Allocate a 2048x2048 lower-triangular UINT8 causal mask on device. +// Required for sparseMode >= 2. +inline aclTensor* makeCausalMask(void** mask_buf, aclrtStream stream) { + constexpr int64_t kMaskDim = 2048; + const int64_t mask_elems = kMaskDim * kMaskDim; + const size_t mask_bytes = static_cast(mask_elems); // uint8_t + + aclrtMalloc(mask_buf, mask_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + std::vector host_mask(mask_elems); + for (int64_t r = 0; r < kMaskDim; ++r) { + for (int64_t c = 0; c < kMaskDim; ++c) { + // 1 = masked out (upper triangle); 0 = attend (lower triangle). + host_mask[r * kMaskDim + c] = (c > r) ? 
1 : 0; + } + } + aclrtMemcpyAsync(*mask_buf, mask_bytes, host_mask.data(), mask_bytes, + ACL_MEMCPY_HOST_TO_DEVICE, stream); + aclrtSynchronizeStream(stream); + + std::vector mask_shape = {kMaskDim, kMaskDim}; + std::vector mask_strides = {kMaskDim, 1}; + std::vector mask_storage = {mask_elems}; + return aclCreateTensor(mask_shape.data(), 2, ACL_UINT8, mask_strides.data(), + 0, ACL_FORMAT_ND, mask_storage.data(), 1, *mask_buf); +} + +} // namespace detail + +template <> +class Operator : public FlashAttention { + public: + using FlashAttention::FlashAttention; + + void operator()(const Tensor query, const Tensor key, const Tensor value, + std::optional cu_seqlens_q, + std::optional cu_seqlens_kv, + std::optional block_table, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, double scale, + bool causal, int64_t window_left, int64_t window_right, + int64_t block_size, Tensor output) const override { + auto stream = static_cast(stream_); + const bool paged = block_table.has_value() && block_size > 0; + + // Map causal + window_left/right to FIA sparse_mode / preTokens / + // nextTokens. + // + // causal=true, window_left<0 -> sparse_mode=3 (full causal) + // causal=true, window_left>=0 -> sparse_mode=4 (sliding + // window causal) causal=false -> sparse_mode=0 + // (no mask) + // + // sparse_mode is ignored by FIA when Q_S=1 (paged decode); effective_sparse + // is set to 0 in that path to avoid allocating the unnecessary causal mask. 
+ int64_t sparse_mode; + int64_t pre_tokens = 2147483647; + int64_t next_tokens = 2147483647; + if (causal) { + if (window_left >= 0) { + sparse_mode = 4; // band: sliding window causal + pre_tokens = window_left; + next_tokens = 0; + } else { + sparse_mode = 3; // rightDownCausal: full causal, pre/next ignored + next_tokens = 0; + } + } else { + sparse_mode = 0; + if (window_left >= 0) pre_tokens = window_left; + if (window_right >= 0) next_tokens = window_right; + } + + if (!paged) { + // --- Prefill (single- or multi-sequence) --- + // V4 TND: query/key/value passed as token-packed [T, N, D]; per-sequence + // lengths are derived from cu_seqlens. Single fused call for all + // sequences, equivalent to flash_attn_varlen_func on CUDA. + int64_t T = query.size(0); + + // V4 TND varlen uses cumulative end positions [s1, s1+s2, ...]. + // For single-seq (no cu_seqlens), [T] is both per-seq and cumulative. + aclIntArray* seq_q = + cu_seqlens_q.has_value() + ? detail::cumSeqLengths(cu_seqlens_q.value(), stream) + : aclCreateIntArray(&T, 1); + aclIntArray* seq_kv = + cu_seqlens_kv.has_value() + ? detail::cumSeqLengths(cu_seqlens_kv.value(), stream) + : aclCreateIntArray(&T, 1); + + aclTensor* t_q = ascend::buildAclTensor(query); + aclTensor* t_k = ascend::buildAclTensor(key); + aclTensor* t_v = ascend::buildAclTensor(value); + aclTensor* t_out = ascend::buildAclTensor(output); + + const aclTensor* k_arr[] = {t_k}; + const aclTensor* v_arr[] = {t_v}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + // sparseMode 2/3/4 require a 2048x2048 lower-triangular causal mask. 
+ aclTensor* atten_mask = nullptr; + void* mask_buf = nullptr; + if (sparse_mode >= 2) { + atten_mask = detail::makeCausalMask(&mask_buf, stream); + } + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + // Parameter order: query, key, value, + // pseShift, attenMask, actualSeqLengths, actualSeqLengthsKv, + // deqScale1, quantScale1, deqScale2, quantScale2, quantOffset2, + // antiquantScale, antiquantOffset, + // blockTable, queryPaddingSize, kvPaddingSize, + // keyAntiquantScale, keyAntiquantOffset, + // valueAntiquantScale, valueAntiquantOffset, + // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen, + // queryRope, keyRope, keyRopeAntiquantScale, + // dequantScaleQuery, learnableSink, + // numHeads, scaleValue, preTokens, nextTokens, inputLayout, + // numKeyValueHeads, sparseMode, innerPrecise, blockSize, + // antiquantMode, softmaxLseFlag, + // keyAntiquantMode, valueAntiquantMode, queryQuantMode, + // attentionOut, softmaxLse, workspaceSize, executor + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_q, key_list, val_list, + nullptr, // pseShift + atten_mask, // attenMask + seq_q, // actualSeqLengths + seq_kv, // actualSeqLengthsKv + nullptr, nullptr, nullptr, nullptr, + nullptr, // deqScale1..quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // queryPaddingSize, kvPaddingSize + nullptr, nullptr, nullptr, + nullptr, // key/value antiquant scale/offset + nullptr, nullptr, + nullptr, // keySharedPrefix, valueSharedPrefix, actualSharedPrefixLen + nullptr, nullptr, + nullptr, // queryRope, keyRope, keyRopeAntiquantScale + nullptr, nullptr, // dequantScaleQuery, learnableSink + num_heads, scale, pre_tokens, next_tokens, const_cast("TND"), + num_kv_heads, sparse_mode, + 0, // innerPrecise + 0, // blockSize (unused for prefill) + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_out, nullptr, 
&ws_needed, &executor); + assert( + gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (prefill)"); + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclError ret = aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, + executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (prefill)"); + + aclDestroyTensor(t_q); + aclDestroyTensor(t_out); + aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyIntArray(seq_q); + aclDestroyIntArray(seq_kv); + if (atten_mask) aclDestroyTensor(atten_mask); + if (mask_buf) aclrtFree(mask_buf); + return; + } + + // --- Paged decode --- + // V4 BNSD: reshape query/output [B, N, D] -> [B, N, 1, D]. + // KV cache [num_blocks, block_size, N_kv, D] flattened to + // [num_blocks, block_size, N_kv*D] (zero-copy, FIA BSH kv format). + assert(cu_seqlens_kv.has_value() && + "`FlashAttention` paged decode requires `cu_seqlens_kv`"); + + const int64_t N = query.size(1); + const int64_t D = query.size(2); + const int64_t B = query.size(0); + const int64_t nb = key.size(0); + const int64_t bsz = key.size(1); + const int64_t NkvD = key.size(2) * key.size(3); + + std::vector bnsd_sh = {B, N, 1, D}; + std::vector bnsd_st = {N * D, D, D, 1}; + aclTensor* t_query = detail::reshapeView(query, bnsd_sh, bnsd_st); + aclTensor* t_output = detail::reshapeView(output, bnsd_sh, bnsd_st); + + std::vector kv_sh = {nb, bsz, NkvD}; + std::vector kv_st = {bsz * NkvD, NkvD, 1}; + aclTensor* t_key = detail::reshapeView(key, kv_sh, kv_st); + aclTensor* t_value = detail::reshapeView(value, kv_sh, kv_st); + + aclIntArray* seq_kv = + detail::extractSeqLengths(cu_seqlens_kv.value(), stream); + aclTensor* t_block_table = ascend::buildAclTensor(block_table.value()); + + const aclTensor* k_arr[] = {t_key}; + const aclTensor* v_arr[] = {t_value}; + aclTensorList* key_list = aclCreateTensorList(k_arr, 1); + aclTensorList* val_list = aclCreateTensorList(v_arr, 1); + + 
uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclError gws = aclnnFusedInferAttentionScoreV4GetWorkspaceSize( + t_query, key_list, val_list, + nullptr, // pseShift + nullptr, // attenMask (sparseMode ignored for Q_S=1) + nullptr, // actualSeqLengths (ignored for Q_S=1) + seq_kv, // actualSeqLengthsKv (mandatory for paged) + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + t_block_table, // blockTable + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, num_heads, scale, + static_cast(2147483647), static_cast(2147483647), + const_cast("BNSD"), num_kv_heads, + 0, // sparseMode=0 (ignored for Q_S=1) + 0, // innerPrecise + block_size, // blockSize + 0, false, // antiquantMode, softmaxLseFlag + 0, 0, 0, // keyAntiquantMode, valueAntiquantMode, queryQuantMode + t_output, nullptr, &ws_needed, &executor); + assert(gws == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4GetWorkspaceSize failed (decode)"); + + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclError ret = + aclnnFusedInferAttentionScoreV4(arena.buf, ws_needed, executor, stream); + assert(ret == ACL_SUCCESS && + "aclnnFusedInferAttentionScoreV4 failed (decode)"); + + aclDestroyTensor(t_query); + aclDestroyTensor(t_output); + aclDestroyTensorList(key_list); + aclDestroyTensorList(val_list); + aclDestroyTensor(t_block_table); + aclDestroyIntArray(seq_kv); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/matmul/kernel.h b/src/ascend/matmul/kernel.h new file mode 100644 index 0000000..4070634 --- /dev/null +++ b/src/ascend/matmul/kernel.h @@ -0,0 +1,44 @@ +#ifndef INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/matmul.h" +#include "operator.h" + +namespace 
infini::ops { + +template <> +class Operator : public Matmul { + public: + Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : Matmul(a, b, c, trans_a, trans_b) {} + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + auto stream = static_cast(stream_); + auto t_a = ascend::buildAclTensor(a, trans_a); + auto t_b = ascend::buildAclTensor(b, trans_b); + auto t_out = ascend::buildAclTensor(c); + + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + // cube_math_type = 1: allow fp16 accumulation. + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_needed, + &executor); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnMatmul(arena.buf, ws_needed, executor, stream); + + aclDestroyTensor(t_a); + aclDestroyTensor(t_b); + aclDestroyTensor(t_out); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/reshape_and_cache/kernel.h b/src/ascend/reshape_and_cache/kernel.h new file mode 100644 index 0000000..609a1ee --- /dev/null +++ b/src/ascend/reshape_and_cache/kernel.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ +#define INFINI_OPS_ASCEND_RESHAPE_AND_CACHE_KERNEL_H_ + +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/device_.h" +#include "base/reshape_and_cache.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator + : public ReshapeAndCache { + public: + using ReshapeAndCache::ReshapeAndCache; + + void operator()(const Tensor key, const Tensor value, const Tensor kv_cache, + const Tensor slot_mapping, + Tensor kv_cache_out) const override { + auto stream = static_cast(stream_); + + // Copy slot_mapping to host for address computation. 
+ auto num_tokens = static_cast(num_tokens_); + std::vector slots(num_tokens); + aclrtMemcpyAsync(slots.data(), num_tokens * sizeof(int64_t), + slot_mapping.data(), num_tokens * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, stream); + aclrtSynchronizeStream(stream); + + auto bs = static_cast(block_size_); + auto row_bytes = static_cast(num_kv_heads_ * head_size_) * + kDataTypeToSize.at(key.dtype()); + + // kv_cache layout: [2, num_blocks, block_size, num_kv_heads, head_size] + // kv_cache[0] = key cache, kv_cache[1] = value cache. + // Stride for the first dim (K vs V): kv_cache.stride(0). + auto kv_stride0 = static_cast(kv_cache_out.stride(0)); + + for (int64_t i = 0; i < num_tokens; ++i) { + auto slot = slots[i]; + if (slot < 0) continue; // Padding token — skip. + auto block_idx = slot / bs; + auto offset = slot % bs; + + auto cache_offset = (block_idx * kv_cache_out.stride(1) + + offset * kv_cache_out.stride(2)) * + kv_cache_out.element_size(); + + auto* k_src = static_cast(key.data()) + + i * key.stride(0) * key.element_size(); + auto* k_dst = static_cast(kv_cache_out.data()) + cache_offset; + aclrtMemcpyAsync(k_dst, row_bytes, k_src, row_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + + auto* v_src = static_cast(value.data()) + + i * value.stride(0) * value.element_size(); + auto* v_dst = static_cast(kv_cache_out.data()) + + kv_stride0 * kv_cache_out.element_size() + cache_offset; + aclrtMemcpyAsync(v_dst, row_bytes, v_src, row_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/rms_norm/kernel.h b/src/ascend/rms_norm/kernel.h new file mode 100644 index 0000000..9eef1bb --- /dev/null +++ b/src/ascend/rms_norm/kernel.h @@ -0,0 +1,62 @@ +#ifndef INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_rms_norm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" 
+#include "base/rms_norm.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public RmsNorm { + public: + Operator(const Tensor input, const Tensor weight, float eps, Tensor out) + : RmsNorm(input, weight, eps, out) { + // aclnnRmsNorm writes rstd as a required side output. + // Allocate a persistent device buffer for it. + rstd_shape_ = {static_cast(batch_size_), + static_cast(nhead_)}; + size_t rstd_bytes = batch_size_ * nhead_ * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + + ~Operator() { + if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + auto t_in = ascend::buildAclTensor(input); + auto t_weight = ascend::buildAclTensor(weight); + auto t_out = ascend::buildAclTensor(out); + // rstd is always float32 regardless of input dtype. + auto t_rstd = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_data_); + uint64_t ws_needed = 0; + aclOpExecutor* executor = nullptr; + aclnnRmsNormGetWorkspaceSize(t_in, t_weight, eps, t_out, t_rstd, &ws_needed, + &executor); + auto stream = static_cast(stream_); + auto& arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnRmsNorm(arena.buf, ws_needed, executor, stream); + aclDestroyTensor(t_in); + aclDestroyTensor(t_weight); + aclDestroyTensor(t_out); + aclDestroyTensor(t_rstd); + } + + private: + std::vector rstd_shape_; + void* rstd_data_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/rotary_embedding/kernel.h b/src/ascend/rotary_embedding/kernel.h new file mode 100644 index 0000000..5c3da01 --- /dev/null +++ b/src/ascend/rotary_embedding/kernel.h @@ -0,0 +1,505 @@ +#ifndef INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_ +#define INFINI_OPS_ASCEND_ROTARY_EMBEDDING_KERNEL_H_ + +#include +#include +#include +#include + +#include "acl/acl.h" +#include 
"aclnn/aclnn_base.h" +#include "aclnnop/aclnn_apply_rotary_pos_emb_v2.h" +#include "aclnnop/aclnn_index_select.h" +#include "aclnnop/aclnn_rotary_position_embedding.h" +#include "ascend/data_type_.h" +#include "ascend/workspace_pool_.h" +#include "base/rotary_embedding.h" +#include "operator.h" + +namespace infini::ops { + +// aclnnApplyRotaryPosEmbV2 hardware constraints on Atlas A2/A3: +// - rotaryMode "half" only (neox style) +// - D (last dim of queryRef) must be 64 or 128 +// - bfloat16 only (float16 accumulates with ~1 ULP error that exceeds +// atol=0.001 in tests; bfloat16 passes with atol=0.005) +// +// Use V2 when all three hold; fall back to V1 otherwise. +static bool use_rope_v2(int64_t D, bool is_neox, DataType dtype) { + return is_neox && (D == 64 || D == 128) && dtype == DataType::kBFloat16; +} + +template <> +class Operator + : public RotaryEmbedding { + public: + Operator(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, int64_t rotary_dim, + bool is_neox_style, Tensor query_out, Tensor key_out) + : RotaryEmbedding(positions, query, key, cos_sin_cache, head_size, + rotary_dim, is_neox_style, query_out, key_out) { + const int64_t max_seq_len = cos_sin_cache.size(0); + const int64_t R = rotary_dim_; + const int64_t half_R = R / 2; + cache_elem_size_ = cos_sin_cache.element_size(); + + // Copy raw cache to host for pre-expansion (one-time cost). + size_t raw_bytes = static_cast(max_seq_len * R) * cache_elem_size_; + std::vector cache_host(raw_bytes); + aclrtMemcpy(cache_host.data(), raw_bytes, cos_sin_cache.data(), raw_bytes, + ACL_MEMCPY_DEVICE_TO_HOST); + + // Pre-expand into separate cos/sin tables with duplicated values. + // After expansion each row is R-wide: + // neox: cos = [c0..c_{hR-1}, c0..c_{hR-1}] (first half repeated) + // interleave: cos = [c0,c0, c1,c1, ..., c_{hR-1},c_{hR-1}] + // Same pattern for sin. 
+ table_bytes_ = raw_bytes; + std::vector cos_table_host(table_bytes_); + std::vector sin_table_host(table_bytes_); + + for (int64_t p = 0; p < max_seq_len; ++p) { + if (is_neox_style_) { + for (int64_t j = 0; j < half_R; ++j) { + const uint8_t* c_src = + cache_host.data() + + static_cast(p * R + j) * cache_elem_size_; + const uint8_t* s_src = + cache_host.data() + + static_cast(p * R + half_R + j) * cache_elem_size_; + auto* cos_dst = cos_table_host.data(); + auto* sin_dst = sin_table_host.data(); + std::memcpy( + cos_dst + static_cast(p * R + j) * cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy(cos_dst + static_cast(p * R + half_R + j) * + cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy( + sin_dst + static_cast(p * R + j) * cache_elem_size_, + s_src, cache_elem_size_); + std::memcpy(sin_dst + static_cast(p * R + half_R + j) * + cache_elem_size_, + s_src, cache_elem_size_); + } + } else { + for (int64_t j = 0; j < half_R; ++j) { + const uint8_t* c_src = + cache_host.data() + + static_cast(p * R + j) * cache_elem_size_; + const uint8_t* s_src = + cache_host.data() + + static_cast(p * R + half_R + j) * cache_elem_size_; + auto* cos_dst = cos_table_host.data(); + auto* sin_dst = sin_table_host.data(); + std::memcpy( + cos_dst + static_cast(p * R + 2 * j) * cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy(cos_dst + static_cast(p * R + 2 * j + 1) * + cache_elem_size_, + c_src, cache_elem_size_); + std::memcpy( + sin_dst + static_cast(p * R + 2 * j) * cache_elem_size_, + s_src, cache_elem_size_); + std::memcpy(sin_dst + static_cast(p * R + 2 * j + 1) * + cache_elem_size_, + s_src, cache_elem_size_); + } + } + } + + // Upload expanded tables to device (one-time). 
+ aclrtMalloc(&cos_table_dev_, table_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_table_dev_, table_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(cos_table_dev_, table_bytes_, cos_table_host.data(), + table_bytes_, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(sin_table_dev_, table_bytes_, sin_table_host.data(), + table_bytes_, ACL_MEMCPY_HOST_TO_DEVICE); + + const int64_t T = num_tokens_; + const int64_t Nq = num_heads_; + const int64_t Nkv = num_kv_heads_; + const int64_t D = head_size_; + const bool v2 = use_rope_v2(R, is_neox_style_, query.dtype()); + use_v2_ = v2; + + // Gathered output buffers [T, R] — filled by aclnnIndexSelect at runtime. + gathered_cs_bytes_ = static_cast(T * R) * cache_elem_size_; + aclrtMalloc(&cos_dev_, gathered_cs_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&sin_dev_, gathered_cs_bytes_, ACL_MEM_MALLOC_NORMAL_ONLY); + + // Scratch for partial-rotation (R < D) — used by both V1 and V2. + if (R < D) { + size_t q_rot_bytes = static_cast(T * Nq * R) * cache_elem_size_; + size_t k_rot_bytes = static_cast(T * Nkv * R) * cache_elem_size_; + aclrtMalloc(&q_rot_dev_, q_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&k_rot_dev_, k_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + if (!v2) { + aclrtMalloc(&q_out_rot_dev_, q_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMalloc(&k_out_rot_dev_, k_rot_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + } + } + + ~Operator() { + if (cos_table_dev_) aclrtFree(cos_table_dev_); + if (sin_table_dev_) aclrtFree(sin_table_dev_); + if (cos_dev_) aclrtFree(cos_dev_); + if (sin_dev_) aclrtFree(sin_dev_); + if (q_rot_dev_) aclrtFree(q_rot_dev_); + if (k_rot_dev_) aclrtFree(k_rot_dev_); + if (q_out_rot_dev_) aclrtFree(q_out_rot_dev_); + if (k_out_rot_dev_) aclrtFree(k_out_rot_dev_); + } + + void operator()(const Tensor positions, const Tensor query, const Tensor key, + const Tensor cos_sin_cache, int64_t head_size, + int64_t rotary_dim, bool is_neox_style, Tensor query_out, + Tensor key_out) const 
override { + auto stream = static_cast(stream_); + + const int64_t T = query.size(0); + const int64_t Nq = query.size(1); + const int64_t Nkv = key.size(1); + const int64_t D = head_size; + const int64_t R = rotary_dim; + const int64_t max_seq_len = cos_sin_cache.size(0); + + assert(R <= D); + assert(cos_sin_cache.size(1) == R); + + // 1. Gather cos/sin on device via aclnnIndexSelect — fully async. + // No host sync, no D2H copy. Positions stay on device. + { + aclDataType acl_dt_cs = ascend::toAclDtype(query.dtype()); + + // Table tensors: [max_seq_len, R] + std::vector table_shape = {max_seq_len, R}; + std::vector table_strides = {R, 1}; + std::vector table_storage = {max_seq_len * R}; + + aclTensor* t_cos_table = aclCreateTensor( + table_shape.data(), 2, acl_dt_cs, table_strides.data(), 0, + ACL_FORMAT_ND, table_storage.data(), 1, cos_table_dev_); + aclTensor* t_sin_table = aclCreateTensor( + table_shape.data(), 2, acl_dt_cs, table_strides.data(), 0, + ACL_FORMAT_ND, table_storage.data(), 1, sin_table_dev_); + + // Index tensor: positions [T], int64 — stays on device. + std::vector idx_shape = {T}; + std::vector idx_strides = {1}; + std::vector idx_storage = {T}; + aclTensor* t_idx = aclCreateTensor( + idx_shape.data(), 1, ACL_INT64, idx_strides.data(), 0, ACL_FORMAT_ND, + idx_storage.data(), 1, const_cast(positions.data())); + + // Output tensors: [T, R] + std::vector out_shape = {T, R}; + std::vector out_strides = {R, 1}; + std::vector out_storage = {T * R}; + + aclTensor* t_cos_out = + aclCreateTensor(out_shape.data(), 2, acl_dt_cs, out_strides.data(), 0, + ACL_FORMAT_ND, out_storage.data(), 1, cos_dev_); + aclTensor* t_sin_out = + aclCreateTensor(out_shape.data(), 2, acl_dt_cs, out_strides.data(), 0, + ACL_FORMAT_ND, out_storage.data(), 1, sin_dev_); + + // Get workspace sizes and executors for both gathers. 
+ uint64_t ws_cos = 0, ws_sin = 0; + aclOpExecutor *exec_cos = nullptr, *exec_sin = nullptr; + aclnnIndexSelectGetWorkspaceSize(t_cos_table, 0, t_idx, t_cos_out, + &ws_cos, &exec_cos); + aclnnIndexSelectGetWorkspaceSize(t_sin_table, 0, t_idx, t_sin_out, + &ws_sin, &exec_sin); + + // Single workspace buffer large enough for both calls. + uint64_t ws_max = ws_cos > ws_sin ? ws_cos : ws_sin; + auto& arena = ascend::workspacePool().ensure(stream, ws_max); + + aclnnIndexSelect(arena.buf, ws_cos, exec_cos, stream); + aclnnIndexSelect(arena.buf, ws_sin, exec_sin, stream); + + aclDestroyTensor(t_cos_table); + aclDestroyTensor(t_sin_table); + aclDestroyTensor(t_idx); + aclDestroyTensor(t_cos_out); + aclDestroyTensor(t_sin_out); + } + + aclDataType acl_dt = ascend::toAclDtype(query.dtype()); + + if (use_v2_) { + // V2: fused Q+K, in-place, layout=4 (T-first 3D), "half" mode. + // cos/sin shape: [T, 1, R]. + std::vector cs_shape = {T, 1, R}; + std::vector cs_strides = {R, R, 1}; + std::vector cs_storage = {T * R}; + aclTensor* t_cos = + aclCreateTensor(cs_shape.data(), 3, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, cos_dev_); + aclTensor* t_sin = + aclCreateTensor(cs_shape.data(), 3, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, sin_dev_); + + int64_t layout = 4; + if (R == D) { + apply_rope_v2_full(query, key, query_out, key_out, T, Nq, Nkv, D, + acl_dt, t_cos, t_sin, layout, stream); + } else { + apply_rope_v2_partial(query, key, query_out, key_out, T, Nq, Nkv, D, R, + acl_dt, t_cos, t_sin, layout, stream); + } + aclDestroyTensor(t_cos); + aclDestroyTensor(t_sin); + } else { + // V1: separate Q and K calls, non-in-place, [1,T,1,R] cos/sin. 
+ std::vector cs_shape = {1, T, 1, R}; + std::vector cs_strides = {T * R, R, R, 1}; + std::vector cs_storage = {T * R}; + aclTensor* t_cos = + aclCreateTensor(cs_shape.data(), 4, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, cos_dev_); + aclTensor* t_sin = + aclCreateTensor(cs_shape.data(), 4, acl_dt, cs_strides.data(), 0, + ACL_FORMAT_ND, cs_storage.data(), 1, sin_dev_); + + int64_t mode = is_neox_style ? 0 : 1; + apply_rope_v1(query, query_out, T, Nq, D, R, mode, t_cos, t_sin, + q_rot_dev_, q_out_rot_dev_, stream); + apply_rope_v1(key, key_out, T, Nkv, D, R, mode, t_cos, t_sin, k_rot_dev_, + k_out_rot_dev_, stream); + + aclDestroyTensor(t_cos); + aclDestroyTensor(t_sin); + } + } + + private: + size_t cache_elem_size_ = 1; + + // Pre-expanded cos/sin tables on device: [max_seq_len, R]. + // Built once in the constructor with neox/interleave duplication. + void* cos_table_dev_ = nullptr; + void* sin_table_dev_ = nullptr; + size_t table_bytes_ = 0; + + // true when V2 hardware constraints are met (neox, D∈{64,128}, bf16). + bool use_v2_ = false; + + // Device buffers for gathered [T, R] cos/sin (shared by V1 and V2). + void* cos_dev_ = nullptr; + void* sin_dev_ = nullptr; + size_t gathered_cs_bytes_ = 0; + + // Scratch for partial rotation (R < D). 
+ void* q_rot_dev_ = nullptr; + void* k_rot_dev_ = nullptr; + void* q_out_rot_dev_ = nullptr; + void* k_out_rot_dev_ = nullptr; + + // --- V2 helpers (neox bf16, D∈{64,128}) --- + + void apply_rope_v2_full(const Tensor& q, const Tensor& k, Tensor& q_out, + Tensor& k_out, int64_t T, int64_t Nq, int64_t Nkv, + int64_t D, aclDataType acl_dt, aclTensor* t_cos, + aclTensor* t_sin, int64_t layout, + aclrtStream stream) const { + size_t elem_sz = q.element_size(); + if (q.data() != q_out.data()) { + aclrtMemcpyAsync(const_cast(q_out.data()), + static_cast(T * Nq * D) * elem_sz, q.data(), + static_cast(T * Nq * D) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + if (k.data() != k_out.data()) { + size_t k_elem_sz = k.element_size(); + aclrtMemcpyAsync(const_cast(k_out.data()), + static_cast(T * Nkv * D) * k_elem_sz, k.data(), + static_cast(T * Nkv * D) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + std::vector q_shape = {T, Nq, D}; + std::vector q_strides = {Nq * D, D, 1}; + std::vector q_storage = {T * Nq * D}; + std::vector k_shape = {T, Nkv, D}; + std::vector k_strides = {Nkv * D, D, 1}; + std::vector k_storage = {T * Nkv * D}; + aclTensor* t_q = aclCreateTensor( + q_shape.data(), 3, acl_dt, q_strides.data(), 0, ACL_FORMAT_ND, + q_storage.data(), 1, const_cast(q_out.data())); + aclTensor* t_k = aclCreateTensor( + k_shape.data(), 3, acl_dt, k_strides.data(), 0, ACL_FORMAT_ND, + k_storage.data(), 1, const_cast(k_out.data())); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnApplyRotaryPosEmbV2GetWorkspaceSize( + t_q, t_k, t_cos, t_sin, layout, const_cast("half"), &ws, &exec); + auto& arena = ascend::workspacePool().ensure(stream, ws); + aclnnApplyRotaryPosEmbV2(arena.buf, ws, exec, stream); + aclDestroyTensor(t_q); + aclDestroyTensor(t_k); + } + + void apply_rope_v2_partial(const Tensor& q, const Tensor& k, Tensor& q_out, + Tensor& k_out, int64_t T, int64_t Nq, int64_t Nkv, + int64_t D, int64_t R, aclDataType acl_dt, + aclTensor* t_cos, 
aclTensor* t_sin, int64_t layout, + aclrtStream stream) const { + size_t elem_sz = q.element_size(); + size_t k_elem_sz = k.element_size(); + const int64_t pass = D - R; + + for (int64_t i = 0; i < T * Nq; ++i) { + aclrtMemcpyAsync(static_cast(q_rot_dev_) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + static_cast(q.data()) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + for (int64_t i = 0; i < T * Nkv; ++i) { + aclrtMemcpyAsync(static_cast(k_rot_dev_) + + static_cast(i * R) * k_elem_sz, + static_cast(R) * k_elem_sz, + static_cast(k.data()) + + static_cast(i * D) * k_elem_sz, + static_cast(R) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + std::vector qr_shape = {T, Nq, R}; + std::vector qr_strides = {Nq * R, R, 1}; + std::vector qr_storage = {T * Nq * R}; + std::vector kr_shape = {T, Nkv, R}; + std::vector kr_strides = {Nkv * R, R, 1}; + std::vector kr_storage = {T * Nkv * R}; + aclTensor* t_q_rot = + aclCreateTensor(qr_shape.data(), 3, acl_dt, qr_strides.data(), 0, + ACL_FORMAT_ND, qr_storage.data(), 1, q_rot_dev_); + aclTensor* t_k_rot = + aclCreateTensor(kr_shape.data(), 3, acl_dt, kr_strides.data(), 0, + ACL_FORMAT_ND, kr_storage.data(), 1, k_rot_dev_); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q_rot, t_k_rot, t_cos, t_sin, + layout, const_cast("half"), + &ws, &exec); + auto& arena = ascend::workspacePool().ensure(stream, ws); + aclnnApplyRotaryPosEmbV2(arena.buf, ws, exec, stream); + aclDestroyTensor(t_q_rot); + aclDestroyTensor(t_k_rot); + + for (int64_t i = 0; i < T * Nq; ++i) { + aclrtMemcpyAsync(static_cast(const_cast(q_out.data())) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + static_cast(q_rot_dev_) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + aclrtMemcpyAsync(static_cast(const_cast(q_out.data())) + + static_cast(i * D + R) * elem_sz, + 
static_cast(pass) * elem_sz, + static_cast(q.data()) + + static_cast(i * D + R) * elem_sz, + static_cast(pass) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + for (int64_t i = 0; i < T * Nkv; ++i) { + aclrtMemcpyAsync(static_cast(const_cast(k_out.data())) + + static_cast(i * D) * k_elem_sz, + static_cast(R) * k_elem_sz, + static_cast(k_rot_dev_) + + static_cast(i * R) * k_elem_sz, + static_cast(R) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + aclrtMemcpyAsync(static_cast(const_cast(k_out.data())) + + static_cast(i * D + R) * k_elem_sz, + static_cast(pass) * k_elem_sz, + static_cast(k.data()) + + static_cast(i * D + R) * k_elem_sz, + static_cast(pass) * k_elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + } + + // --- V1 helper (fallback for non-neox, fp16, or D not in {64,128}) --- + + void apply_rope_v1(const Tensor& x, Tensor& out, int64_t T, int64_t N, + int64_t D, int64_t R, int64_t mode, aclTensor* t_cos, + aclTensor* t_sin, void* x_rot_dev, void* out_rot_dev, + aclrtStream stream) const { + aclDataType acl_dt = ascend::toAclDtype(x.dtype()); + size_t elem_sz = x.element_size(); + + if (R < D) { + for (int64_t i = 0; i < T * N; ++i) { + aclrtMemcpyAsync(static_cast(x_rot_dev) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + static_cast(x.data()) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + std::vector rot_sh = {1, T, N, R}; + std::vector rot_st = {T * N * R, N * R, R, 1}; + std::vector rot_storage = {T * N * R}; + aclTensor* t_x_rot = + aclCreateTensor(rot_sh.data(), 4, acl_dt, rot_st.data(), 0, + ACL_FORMAT_ND, rot_storage.data(), 1, x_rot_dev); + aclTensor* t_out_rot = + aclCreateTensor(rot_sh.data(), 4, acl_dt, rot_st.data(), 0, + ACL_FORMAT_ND, rot_storage.data(), 1, out_rot_dev); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnRotaryPositionEmbeddingGetWorkspaceSize(t_x_rot, t_cos, t_sin, mode, + t_out_rot, &ws, &exec); + auto& arena = 
ascend::workspacePool().ensure(stream, ws); + aclnnRotaryPositionEmbedding(arena.buf, ws, exec, stream); + + const int64_t pass = D - R; + for (int64_t i = 0; i < T * N; ++i) { + aclrtMemcpyAsync(static_cast(const_cast(out.data())) + + static_cast(i * D) * elem_sz, + static_cast(R) * elem_sz, + static_cast(out_rot_dev) + + static_cast(i * R) * elem_sz, + static_cast(R) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + aclrtMemcpyAsync(static_cast(const_cast(out.data())) + + static_cast(i * D + R) * elem_sz, + static_cast(pass) * elem_sz, + static_cast(x.data()) + + static_cast(i * D + R) * elem_sz, + static_cast(pass) * elem_sz, + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + } + aclDestroyTensor(t_x_rot); + aclDestroyTensor(t_out_rot); + } else { + std::vector full_sh = {1, T, N, D}; + std::vector full_st = {T * N * D, N * D, D, 1}; + std::vector full_storage = {T * N * D}; + aclTensor* t_x = aclCreateTensor( + full_sh.data(), 4, acl_dt, full_st.data(), 0, ACL_FORMAT_ND, + full_storage.data(), 1, const_cast(x.data())); + aclTensor* t_out = aclCreateTensor( + full_sh.data(), 4, acl_dt, full_st.data(), 0, ACL_FORMAT_ND, + full_storage.data(), 1, const_cast(out.data())); + uint64_t ws = 0; + aclOpExecutor* exec = nullptr; + aclnnRotaryPositionEmbeddingGetWorkspaceSize(t_x, t_cos, t_sin, mode, + t_out, &ws, &exec); + auto& arena = ascend::workspacePool().ensure(stream, ws); + aclnnRotaryPositionEmbedding(arena.buf, ws, exec, stream); + aclDestroyTensor(t_x); + aclDestroyTensor(t_out); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/swiglu/kernel.h b/src/ascend/swiglu/kernel.h new file mode 100644 index 0000000..c7d31e7 --- /dev/null +++ b/src/ascend/swiglu/kernel.h @@ -0,0 +1,70 @@ +#ifndef INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ +#define INFINI_OPS_ASCEND_SWIGLU_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "aclnn_silu.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include 
"base/swiglu.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// Implements SwiGLU as two ACLNN calls: silu(gate) into a temp buffer, +// then elementwise mul(input, temp) into out. +// aclnnSiluMul was not used because it fuses silu_AND_mul on the same +// tensor (x * silu(x)), whereas SwiGLU requires input * silu(gate) — +// two distinct inputs. +template <> +class Operator : public Swiglu { + public: + Operator(const Tensor input, const Tensor gate, Tensor out) + : Swiglu(input, gate, out) { + size_t nbytes = input.numel() * kDataTypeToSize.at(input.dtype()); + aclrtMalloc(&temp_buf_, nbytes, ACL_MEM_MALLOC_NORMAL_ONLY); + } + + ~Operator() { aclrtFree(temp_buf_); } + + void operator()(const Tensor input, const Tensor gate, + Tensor out) const override { + // temp_buf_ is a contiguous scratch buffer; give it contiguous strides. + Tensor temp_t{temp_buf_, gate.shape(), gate.dtype(), gate.device()}; + + auto t_in = ascend::buildAclTensor(input); + auto t_gate = ascend::buildAclTensor(gate); + auto t_out = ascend::buildAclTensor(out); + auto t_temp = ascend::buildAclTensor(temp_t); + + uint64_t ws_needed = 0; + aclOpExecutor* exec = nullptr; + auto stream = static_cast(stream_); + + // Step 1: silu(gate) -> temp. SwiGLU = input * silu(gate). + aclnnSiluGetWorkspaceSize(t_gate, t_temp, &ws_needed, &exec); + auto& silu_arena = ascend::workspacePool().ensure(stream, ws_needed); + aclnnSilu(silu_arena.buf, ws_needed, exec, stream); + + // Step 2: mul(input, temp) -> out. 
+ uint64_t mul_ws = 0; + exec = nullptr; + aclnnMulGetWorkspaceSize(t_in, t_temp, t_out, &mul_ws, &exec); + auto& mul_arena = ascend::workspacePool().ensure(stream, mul_ws); + aclnnMul(mul_arena.buf, mul_ws, exec, stream); + + aclDestroyTensor(t_in); + aclDestroyTensor(t_gate); + aclDestroyTensor(t_out); + aclDestroyTensor(t_temp); + } + + private: + void* temp_buf_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/rotary_embedding.h b/src/base/rotary_embedding.h index a38b20e..70989fa 100644 --- a/src/base/rotary_embedding.h +++ b/src/base/rotary_embedding.h @@ -15,8 +15,8 @@ class RotaryEmbedding : public Operator { int64_t rotary_dim, bool is_neox_style, Tensor query_out, Tensor key_out) : num_tokens_{query.size(0)}, - num_heads_{query.size(1)}, - num_kv_heads_{key.size(1)}, + num_heads_{static_cast(query.size(1))}, + num_kv_heads_{static_cast(key.size(1))}, head_size_{head_size}, rotary_dim_{rotary_dim}, is_neox_style_{is_neox_style}, diff --git a/tests/test_add.py b/tests/test_add.py index 8b8166c..f560435 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,7 +2,13 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randint_strided, randn_strided +from tests.utils import ( + Payload, + empty_strided, + get_npu_stream, + randint_strided, + randn_strided, +) _INT_DTYPES = (torch.int16, torch.int32, torch.int64) @@ -63,7 +69,10 @@ def test_add( def _add(input, other, out): - infini.ops.add(input, other, out) + if input.device.type == "npu": + infini.ops.add(input, other, out, stream=get_npu_stream(input)) + else: + infini.ops.add(input, other, out) return out diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 8b35457..df4894c 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_npu_stream, 
randn_strided @pytest.mark.auto_act_and_assert @@ -40,7 +40,10 @@ def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol, def _causal_softmax(input, out): - infini.ops.causal_softmax(input, out) + if input.device.type == "npu": + infini.ops.causal_softmax(input, out, stream=get_npu_stream(input)) + else: + infini.ops.causal_softmax(input, out) return out @@ -48,7 +51,7 @@ def _causal_softmax(input, out): def _torch_causal_softmax(input, out): mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) - result = torch.nn.functional.softmax(masked, dim=-1, dtype=input.dtype) + result = torch.nn.functional.softmax(masked, dim=-1) out.copy_(result) return out diff --git a/tests/test_e2e_layer.py b/tests/test_e2e_layer.py new file mode 100644 index 0000000..92df9a2 --- /dev/null +++ b/tests/test_e2e_layer.py @@ -0,0 +1,418 @@ +import infini.ops +import pytest +import torch + +from tests.utils import get_npu_stream, randn_strided, randint_strided + + +def _stream_kw(tensor): + if tensor.device.type == "npu": + return {"stream": get_npu_stream(tensor)} + + return {} + + +def _ref_rms_norm(x, weight, eps): + rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps) + + return (x / rms) * weight + + +def _ref_rope( + positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style +): + T = query.size(0) + R = rotary_dim + half_R = R // 2 + cos_half = cos_sin_cache[:, :half_R] + sin_half = cos_sin_cache[:, half_R:] + + def apply_rope(x): + out = x.clone() + + for t in range(T): + p = positions[t].item() + c = cos_half[p] + s = sin_half[p] + + if is_neox_style: + x1 = x[t, :, :half_R] + x2 = x[t, :, half_R:R] + out[t, :, :half_R] = c * x1 - s * x2 + out[t, :, half_R:R] = c * x2 + s * x1 + else: + x1 = x[t, :, 0::2] + x2 = x[t, :, 1::2] + out[t, :, 0::2] = c * x1 - s * x2 + out[t, :, 1::2] = c * x2 + s * x1 + + return out + + return apply_rope(query), 
apply_rope(key) + + +def _ref_sdpa(query, key, value, num_heads, num_kv_heads, head_size, scale, causal): + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = v.repeat_interleave(ratio, dim=0) + + out = torch.nn.functional.scaled_dot_product_attention( + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + scale=scale, + is_causal=causal, + ) + + return out.squeeze(0).transpose(0, 1) + + +def _infiniops_layer( + hidden, + positions, + cos_sin_cache, + input_norm_w, + qkv_proj_w, + o_proj_w, + gate_proj_w, + up_proj_w, + down_proj_w, + post_norm_w, + num_heads, + num_kv_heads, + head_size, + rotary_dim, + intermediate_size, + is_neox_style, + eps, + scale, + num_tokens, +): + """Run one LLaMA decoder layer using InfiniOps kernels.""" + kw = _stream_kw(hidden) + dtype = hidden.dtype + device = hidden.device + hidden_size = hidden.size(-1) + + # Save residual. + residual = hidden.clone() + + # 1. Input RMSNorm. + normed = torch.empty_like(hidden) + infini.ops.rms_norm(hidden, input_norm_w, eps, normed, **kw) + + # 2. QKV projection: [T, D] @ [D, (N+2*Nkv)*H] -> [T, (N+2*Nkv)*H]. + qkv_dim = (num_heads + 2 * num_kv_heads) * head_size + qkv = torch.empty(num_tokens, qkv_dim, dtype=dtype, device=device) + infini.ops.gemm(normed, qkv_proj_w, 1.0, 0.0, False, False, qkv, **kw) + + # Split Q, K, V. + q = ( + qkv[:, : num_heads * head_size] + .reshape( + num_tokens, + num_heads, + head_size, + ) + .contiguous() + ) + k = ( + qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + v = ( + qkv[:, (num_heads + num_kv_heads) * head_size :] + .reshape( + num_tokens, + num_kv_heads, + head_size, + ) + .contiguous() + ) + + # 3. RoPE. 
+ q_rot = torch.empty_like(q) + k_rot = torch.empty_like(k) + infini.ops.rotary_embedding( + positions, + q, + k, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + q_rot, + k_rot, + **kw, + ) + + # 4. Flash attention (single-sequence prefill, causal). + attn_out = torch.empty( + num_tokens, + num_heads, + head_size, + dtype=dtype, + device=device, + ) + infini.ops.flash_attention( + q_rot, + k_rot, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + attn_out, + **kw, + ) + + # 5. O projection: [T, N*H] @ [N*H, D] -> [T, D]. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(attn_2d, o_proj_w, 1.0, 0.0, False, False, o_out, **kw) + + # 6. Residual add. + after_attn = torch.empty_like(residual) + infini.ops.add(residual, o_out, after_attn, **kw) + + # 7. Post-attention RMSNorm. + residual2 = after_attn.clone() + normed2 = torch.empty_like(after_attn) + infini.ops.rms_norm(after_attn, post_norm_w, eps, normed2, **kw) + + # 8. Gate + up projections. + gate = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + up = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.gemm(normed2, gate_proj_w, 1.0, 0.0, False, False, gate, **kw) + infini.ops.gemm(normed2, up_proj_w, 1.0, 0.0, False, False, up, **kw) + + # 9. SwiGLU: ``up * silu(gate)``. + ffn = torch.empty(num_tokens, intermediate_size, dtype=dtype, device=device) + infini.ops.swiglu(up, gate, ffn, **kw) + + # 10. Down projection: [T, FFN] @ [FFN, D] -> [T, D]. + down = torch.empty(num_tokens, hidden_size, dtype=dtype, device=device) + infini.ops.gemm(ffn, down_proj_w, 1.0, 0.0, False, False, down, **kw) + + # 11. Second residual add. 
+ output = torch.empty_like(residual2) + infini.ops.add(residual2, down, output, **kw) + + return output + + +def _reference_layer( + hidden, + positions, + cos_sin_cache, + input_norm_w, + qkv_proj_w, + o_proj_w, + gate_proj_w, + up_proj_w, + down_proj_w, + post_norm_w, + num_heads, + num_kv_heads, + head_size, + rotary_dim, + intermediate_size, + is_neox_style, + eps, + scale, + num_tokens, +): + """PyTorch float32 reference for one LLaMA decoder layer.""" + # Compute in float32 on CPU for accuracy. + h = hidden.float().cpu() + pos = positions.cpu() + csc = cos_sin_cache.float().cpu() + inw = input_norm_w.float().cpu() + qkvw = qkv_proj_w.float().cpu() + ow = o_proj_w.float().cpu() + gw = gate_proj_w.float().cpu() + uw = up_proj_w.float().cpu() + dw = down_proj_w.float().cpu() + pnw = post_norm_w.float().cpu() + + # 1. Input RMSNorm. + residual = h.clone() + normed = _ref_rms_norm(h, inw, eps) + + # 2. QKV projection. + qkv = normed @ qkvw + + q = qkv[:, : num_heads * head_size].reshape(num_tokens, num_heads, head_size) + k = qkv[:, num_heads * head_size : (num_heads + num_kv_heads) * head_size].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + v = qkv[:, (num_heads + num_kv_heads) * head_size :].reshape( + num_tokens, + num_kv_heads, + head_size, + ) + + # 3. RoPE. + q_rot, k_rot = _ref_rope( + pos, + q, + k, + csc, + head_size, + rotary_dim, + is_neox_style, + ) + + # 4. SDPA. + attn_out = _ref_sdpa( + q_rot, k_rot, v, num_heads, num_kv_heads, head_size, scale, causal=True + ) + + # 5. O projection. + attn_2d = attn_out.reshape(num_tokens, num_heads * head_size) + o_out = attn_2d @ ow + + # 6. Residual add. + after_attn = residual + o_out + + # 7. Post-attention RMSNorm. + residual2 = after_attn.clone() + normed2 = _ref_rms_norm(after_attn, pnw, eps) + + # 8. Gate + up projections. + gate = normed2 @ gw + up = normed2 @ uw + + # 9. SwiGLU: ``up * silu(gate)``. + ffn = up * (gate * torch.sigmoid(gate)) + + # 10. Down projection. 
+ down = ffn @ dw + + # 11. Second residual add. + output = residual2 + down + + return output.to(hidden.dtype).to(hidden.device) + + +def _make_rope_cache(max_seq_len, rotary_dim, dtype, device): + """Build a proper RoPE cos/sin cache (bounded to [-1, 1]).""" + freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + t = torch.arange(max_seq_len, dtype=torch.float32) + angles = torch.outer(t, freq) # [max_seq_len, half_dim] + cos_half = torch.cos(angles).to(dtype=dtype, device=device) + sin_half = torch.sin(angles).to(dtype=dtype, device=device) + + return torch.cat([cos_half, sin_half], dim=-1) + + +@pytest.mark.parametrize("device", ("npu",)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 5e-3, 5e-3), + (torch.bfloat16, 1e-2, 2e-2), + ), +) +def test_llama_layer(device, dtype, rtol, atol): + """End-to-end test of a LLaMA decoder layer using InfiniOps kernels.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + # Small LLaMA-like model config. + hidden_size = 512 + num_heads = 8 + num_kv_heads = 2 + head_size = hidden_size // num_heads + intermediate_size = 1024 + num_tokens = 1 + max_seq_len = 16 + rotary_dim = head_size + is_neox_style = True + eps = 1e-6 + scale = 1.0 / head_size**0.5 + + def _scaled_weight(*shape): + return randn_strided(shape, None, dtype=dtype, device=device) / shape[0] ** 0.5 + + # Random weights (stored as [in_features, out_features], Xavier-scaled). 
+ qkv_proj_w = _scaled_weight( + hidden_size, + (num_heads + 2 * num_kv_heads) * head_size, + ) + o_proj_w = _scaled_weight(num_heads * head_size, hidden_size) + gate_proj_w = _scaled_weight(hidden_size, intermediate_size) + up_proj_w = _scaled_weight(hidden_size, intermediate_size) + down_proj_w = _scaled_weight(intermediate_size, hidden_size) + input_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + post_norm_w = torch.ones(hidden_size, dtype=dtype, device=device) + + # Proper cos/sin cache from frequency decomposition (bounded [-1, 1]). + cos_sin_cache = _make_rope_cache(max_seq_len, rotary_dim, dtype, device) + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + + # Input hidden states scaled to prevent value explosion through layers. + hidden = ( + randn_strided( + (num_tokens, hidden_size), + None, + dtype=dtype, + device=device, + ) + / hidden_size**0.5 + ) + + common = dict( + positions=positions, + cos_sin_cache=cos_sin_cache, + input_norm_w=input_norm_w, + qkv_proj_w=qkv_proj_w, + o_proj_w=o_proj_w, + gate_proj_w=gate_proj_w, + up_proj_w=up_proj_w, + down_proj_w=down_proj_w, + post_norm_w=post_norm_w, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + rotary_dim=rotary_dim, + intermediate_size=intermediate_size, + is_neox_style=is_neox_style, + eps=eps, + scale=scale, + num_tokens=num_tokens, + ) + + infini_out = _infiniops_layer(hidden, **common) + ref_out = _reference_layer(hidden, **common) + + max_diff = (infini_out.float() - ref_out.float()).abs().max().item() + assert torch.allclose(infini_out, ref_out, rtol=rtol, atol=atol), ( + f"Max diff: {max_diff}" + ) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py new file mode 100644 index 0000000..4b8be3f --- /dev/null +++ b/tests/test_flash_attention.py @@ -0,0 +1,442 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, 
randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ( + (32, 32, 128), # MHA + (32, 8, 128), # GQA (4x) + (16, 4, 64), # GQA (4x), smaller + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_single( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Single sequence prefill (no block table).""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_tokens = 16 + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + None, + None, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention( + q, + k, + v, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size", + ((32, 8, 128),), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_prefill_multi( + num_heads, + num_kv_heads, + head_size, + dtype, + rtol, + atol, + device, +): + """Multi-sequence prefill with cu_seqlens.""" + if device == 
"npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + seq_lens = [8, 12, 4] + num_tokens = sum(seq_lens) + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_tokens, num_heads, head_size), None, dtype=dtype, device=device + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + output = torch.empty((num_tokens, num_heads, head_size), dtype=dtype, device=device) + + cu_seqlens_q = torch.tensor( + [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))], + dtype=torch.int64, + device=device, + ) + cu_seqlens_kv = cu_seqlens_q.clone() + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + None, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + 0, + o, + ), + lambda q, k, v, o: _ref_flash_attention_multi( + q, + k, + v, + seq_lens, + seq_lens, + num_heads, + num_kv_heads, + head_size, + scale, + causal=True, + ), + (query, key, value, output), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, block_size", + ( + (32, 8, 128, 128), + (16, 4, 64, 128), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_flash_attention_decode( + num_heads, + num_kv_heads, + head_size, + block_size, + dtype, + rtol, + atol, + device, +): + """Decode phase: single token per request with paged KV cache.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_reqs = 3 + kv_len = 16 # Total KV length per request. 
+ num_blocks_per_req = (kv_len + block_size - 1) // block_size + num_blocks = num_reqs * num_blocks_per_req + scale = 1.0 / head_size**0.5 + + query = randn_strided( + (num_reqs, num_heads, head_size), None, dtype=dtype, device=device + ) + # Paged KV cache: vLLM standard layout [num_blocks, block_size, KV_N, D]. + kv_cache = randn_strided( + (num_blocks, block_size, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + output = torch.empty((num_reqs, num_heads, head_size), dtype=dtype, device=device) + + # Block table: request i uses blocks [i*num_blocks_per_req, ...]. + block_table = torch.zeros( + (num_reqs, num_blocks_per_req), dtype=torch.int32, device=device + ) + for i in range(num_reqs): + for j in range(num_blocks_per_req): + block_table[i, j] = i * num_blocks_per_req + j + + cu_seqlens_q = torch.arange(0, num_reqs + 1, dtype=torch.int64, device=device) + cu_seqlens_kv = torch.tensor( + [i * kv_len for i in range(num_reqs + 1)], dtype=torch.int64, device=device + ) + + return Payload( + lambda q, k, v, o: _flash_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + True, + -1, + 0, + block_size, + o, + ), + lambda q, k, v, o: _ref_flash_attention_paged( + q, + k, + block_table, + cu_seqlens_q, + cu_seqlens_kv, + num_heads, + num_kv_heads, + head_size, + block_size, + scale, + causal=True, + ), + (query, kv_cache, kv_cache, output), + {}, + rtol=rtol, + atol=atol, + ) + + +def _flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, +): + if query.device.type == "npu": + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + stream=get_npu_stream(query), + ) + 
else: + infini.ops.flash_attention( + query, + key, + value, + cu_seqlens_q, + cu_seqlens_kv, + block_table, + num_heads, + num_kv_heads, + head_size, + scale, + causal, + window_left, + window_right, + block_size, + output, + ) + + return output + + +def _ref_flash_attention( + query, key, value, num_heads, num_kv_heads, head_size, scale, causal=True +): + """PyTorch SDPA reference for single-sequence prefill.""" + # [T, N, D] -> [N, T, D] + q = query.transpose(0, 1).float() + k = key.transpose(0, 1).float() + v = value.transpose(0, 1).float() + + # GQA: expand K/V to match num_heads. + if num_kv_heads < num_heads: + ratio = num_heads // num_kv_heads + k = k.repeat_interleave(ratio, dim=0) + v = v.repeat_interleave(ratio, dim=0) + + # [N, T, D] -> [1, N, T, D] for scaled_dot_product_attention. + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=scale, is_causal=causal + ) + + # [1, N, T, D] -> [T, N, D] -> original dtype. 
+    return out.squeeze(0).transpose(0, 1).to(query.dtype)
+
+
+def _ref_flash_attention_multi(
+    query,
+    key,
+    value,
+    seq_lens_q,
+    seq_lens_kv,
+    num_heads,
+    num_kv_heads,
+    head_size,
+    scale,
+    causal=True,
+):
+    """PyTorch SDPA reference for multi-sequence prefill."""
+    outputs = []
+    offset = 0
+    for sq, sk in zip(seq_lens_q, seq_lens_kv):
+        q = query[offset : offset + sq]
+        k = key[offset : offset + sq]
+        v = value[offset : offset + sq]
+        out = _ref_flash_attention(
+            q, k, v, num_heads, num_kv_heads, head_size, scale, causal
+        )
+        outputs.append(out)
+        offset += sq
+
+    return torch.cat(outputs, dim=0)
+
+
+def _ref_flash_attention_paged(
+    query,
+    kv_cache_arg,
+    block_table,
+    cu_seqlens_q,
+    cu_seqlens_kv,
+    num_heads,
+    num_kv_heads,
+    head_size,
+    block_size,
+    scale,
+    causal=True,
+):
+    """PyTorch SDPA reference for decode with paged KV cache."""
+    cu_kv = cu_seqlens_kv.cpu()
+    bt = block_table.cpu()
+    cache = kv_cache_arg.cpu()
+    q_cpu = query.cpu()
+    num_reqs = bt.size(0)
+    outputs = []
+
+    for i in range(num_reqs):
+        q = q_cpu[i : i + 1]  # [1, N, D]
+        kv_len = int(cu_kv[i + 1] - cu_kv[i])
+
+        # Gather KV from paged cache.
+        # cache: [num_blocks, block_size, KV_N, D]
+        blocks = bt[i]
+        k_pages = []
+        v_pages = []
+        remaining = kv_len
+        for b in blocks:
+            if remaining <= 0:
+                break
+            take = min(remaining, block_size)
+            # cache layout: [num_blocks, block_size, KV_N, D]
+            # Slice [take, KV_N, D], transpose to [KV_N, take, D] for cat.
+            k_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1))
+            v_pages.append(cache[int(b.item()), :take, :, :].transpose(0, 1))
+            remaining -= take
+        k = torch.cat(k_pages, dim=1)  # [KV_N, kv_len, D]
+        v = torch.cat(v_pages, dim=1)
+
+        # Decode: Q_S=1 attends to all past KV positions; causal masking is
+        # not applicable here (it would mask everything beyond position 0).
+ out = _ref_flash_attention( + q, # [1, N, D] - already TND format + k.transpose(0, 1), # [KV_N, kv_len, D] -> [kv_len, KV_N, D] + v.transpose(0, 1), + num_heads, + num_kv_heads, + head_size, + scale, + causal=False, + ) + outputs.append(out) + + return torch.cat(outputs, dim=0).to(query.device) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index af8b44f..3f48562 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -86,7 +86,13 @@ def test_gemm( def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): if a.device.type == "npu": infini.ops.gemm( - a, b, alpha, beta, trans_a, trans_b, c, + a, + b, + alpha, + beta, + trans_a, + trans_b, + c, stream=get_npu_stream(a), ) else: diff --git a/tests/test_reshape_and_cache.py b/tests/test_reshape_and_cache.py new file mode 100644 index 0000000..813afc3 --- /dev/null +++ b/tests/test_reshape_and_cache.py @@ -0,0 +1,152 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, get_npu_stream, randn_strided + +# ReshapeAndCache only works on NPU (aclrtMemcpy-based), so tests only +# parametrize on float16/bfloat16 and use explicit device parametrization. 
+ + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (1, 8, 128, 4, 16), + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + (16, 2, 128, 8, 64), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_contiguous( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + # Layout: [2, num_blocks, block_size, num_kv_heads, head_size] + # Index 0 = key cache, index 1 = value cache. + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Contiguous slot mapping: token i -> slot i. 
+ slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + return Payload( + _reshape_and_cache, + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "num_tokens, num_kv_heads, head_size, num_blocks, block_size", + ( + (4, 8, 128, 4, 16), + (8, 4, 64, 8, 32), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_reshape_and_cache_noncontiguous_slots( + num_tokens, + num_kv_heads, + head_size, + num_blocks, + block_size, + dtype, + rtol, + atol, + device, +): + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + key = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + value = randn_strided( + (num_tokens, num_kv_heads, head_size), None, dtype=dtype, device=device + ) + kv_cache = torch.zeros( + (2, num_blocks, block_size, num_kv_heads, head_size), + dtype=dtype, + device=device, + ) + # Non-contiguous slots: skip every other slot. 
+ slot_mapping = torch.tensor( + [i * 2 for i in range(num_tokens)], dtype=torch.int64, device=device + ) + + return Payload( + _reshape_and_cache, + _ref_reshape_and_cache, + (key, value, kv_cache, slot_mapping, kv_cache), + {}, + rtol=rtol, + atol=atol, + ) + + +def _reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + if key.device.type == "npu": + infini.ops.reshape_and_cache( + key, value, kv_cache, slot_mapping, kv_cache_out, stream=get_npu_stream(key) + ) + else: + infini.ops.reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out) + + return kv_cache_out + + +def _ref_reshape_and_cache(key, value, kv_cache, slot_mapping, kv_cache_out): + kv_cache_out = kv_cache_out.clone() + slots = slot_mapping.cpu() + block_size = kv_cache_out.size(2) + + for i in range(key.size(0)): + slot = int(slots[i].item()) + + if slot < 0: + continue + + block_idx = slot // block_size + offset = slot % block_size + kv_cache_out[0, block_idx, offset, :, :] = key[i, :, :] + kv_cache_out[1, block_idx, offset, :, :] = value[i, :, :] + + return kv_cache_out diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index d6d4dff..ba540a9 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_npu_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -53,7 +53,10 @@ def test_rms_norm( def _rms_norm(input, weight, *, eps=1e-6, out=None): - infini.ops.rms_norm(input, weight, eps, out) + if input.device.type == "npu": + infini.ops.rms_norm(input, weight, eps, out, stream=get_npu_stream(input)) + else: + infini.ops.rms_norm(input, weight, eps, out) return out diff --git a/tests/test_rotary_embedding.py b/tests/test_rotary_embedding.py new file mode 100644 index 0000000..733ae43 --- /dev/null +++ b/tests/test_rotary_embedding.py @@ -0,0 +1,266 @@ +import infini.ops +import pytest +import torch 
+ +from tests.utils import get_npu_stream, randn_strided, randint_strided + + +def _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, +): + if device == "npu": + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + stream=get_npu_stream(query), + ) + else: + infini.ops.rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + ) + + return query_out, key_out + + +def _ref_rotary_embedding( + positions, query, key, cos_sin_cache, head_size, rotary_dim, is_neox_style +): + """PyTorch reference for RoPE. + + ``cos_sin_cache`` layout: ``[max_seq_len, rotary_dim]`` where the first + ``rotary_dim // 2`` columns are cos and the rest are sin. + """ + T = query.size(0) + R = rotary_dim + half_R = R // 2 + + cos_sin = cos_sin_cache.float() + cos_half = cos_sin[:, :half_R] + sin_half = cos_sin[:, half_R:] + + def apply_rope(x): + out = x.float().clone() + + for t in range(T): + p = positions[t].item() + c = cos_half[p] + s = sin_half[p] + + if is_neox_style: + x1 = x[t, :, :half_R].float() + x2 = x[t, :, half_R:R].float() + out[t, :, :half_R] = c * x1 - s * x2 + out[t, :, half_R:R] = c * x2 + s * x1 + else: + x1 = x[t, :, 0::2].float() + x2 = x[t, :, 1::2].float() + out[t, :, 0::2] = c * x1 - s * x2 + out[t, :, 1::2] = c * x2 + s * x1 + + return out.to(x.dtype) + + return apply_rope(query), apply_rope(key) + + +def _assert_close(actual, expected, rtol, atol): + assert torch.allclose(actual, expected, rtol=rtol, atol=atol), ( + f"Max diff: {(actual.float() - expected.float()).abs().max().item()}" + ) + + +@pytest.mark.parametrize( + "num_heads, head_size", + ( + (32, 128), + (8, 64), + ), +) +@pytest.mark.parametrize("is_neox_style", (True, False)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + 
(torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_rotary_embedding_full( + num_heads, head_size, is_neox_style, dtype, rtol, atol, device +): + """Full rotary: ``rotary_dim == head_size``.""" + if device == "npu" and not (hasattr(torch, "npu") and torch.npu.is_available()): + pytest.skip("NPU not available") + + num_kv_heads = num_heads + rotary_dim = head_size + num_tokens = 16 + max_seq_len = 64 + + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + query = randn_strided( + (num_tokens, num_heads, head_size), + None, + dtype=dtype, + device=device, + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + cos_sin_cache = randn_strided( + (max_seq_len, rotary_dim), + None, + dtype=dtype, + device=device, + ) + query_out = torch.empty_like(query) + key_out = torch.empty_like(key) + + q_out, k_out = _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, + ) + + ref_q, ref_k = _ref_rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + ) + + _assert_close(q_out, ref_q, rtol, atol) + _assert_close(k_out, ref_k, rtol, atol) + + +@pytest.mark.parametrize( + "num_heads, num_kv_heads, head_size, rotary_dim", + ( + (32, 8, 128, 64), + (16, 4, 64, 32), + ), +) +@pytest.mark.parametrize("is_neox_style", (True,)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +@pytest.mark.parametrize("device", ("npu",)) +def test_rotary_embedding_partial( + num_heads, + num_kv_heads, + head_size, + rotary_dim, + is_neox_style, + dtype, + rtol, + atol, + device, +): + """Partial rotary: ``rotary_dim < head_size``.""" + if device == "npu" and not (hasattr(torch, "npu") and 
torch.npu.is_available()): + pytest.skip("NPU not available") + + num_tokens = 16 + max_seq_len = 64 + + positions = randint_strided( + 0, + max_seq_len, + (num_tokens,), + None, + dtype=torch.int64, + device=device, + ) + query = randn_strided( + (num_tokens, num_heads, head_size), + None, + dtype=dtype, + device=device, + ) + key = randn_strided( + (num_tokens, num_kv_heads, head_size), + None, + dtype=dtype, + device=device, + ) + cos_sin_cache = randn_strided( + (max_seq_len, rotary_dim), + None, + dtype=dtype, + device=device, + ) + query_out = torch.empty_like(query) + key_out = torch.empty_like(key) + + q_out, k_out = _rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + query_out, + key_out, + device, + ) + + ref_q, ref_k = _ref_rotary_embedding( + positions, + query, + key, + cos_sin_cache, + head_size, + rotary_dim, + is_neox_style, + ) + + _assert_close(q_out, ref_q, rtol, atol) + _assert_close(k_out, ref_k, rtol, atol) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 89c95f7..71eaceb 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, rand_strided +from tests.utils import Payload, empty_strided, get_npu_stream, rand_strided @pytest.mark.auto_act_and_assert @@ -38,7 +38,10 @@ def test_swiglu( def _swiglu(input, gate, out): - infini.ops.swiglu(input, gate, out) + if input.device.type == "npu": + infini.ops.swiglu(input, gate, out, stream=get_npu_stream(input)) + else: + infini.ops.swiglu(input, gate, out) return out